未验证 提交 731d45a3 编写于 作者: Q qingqing01 提交者: GitHub

Enable BatchNorm to use global mean and variane during training (#14630)

* Enable BatchNorm to use global mean and variane during training
* Update doc and follow comments.
上级 400cf19f
......@@ -69,7 +69,7 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name']
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
......
......@@ -146,7 +146,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
bool global_stats = is_test || use_global_stats;
const auto *x = ctx.Input<Tensor>("X");
const auto *mean = ctx.Input<Tensor>("Mean");
......@@ -177,13 +179,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
T *batch_mean_data = nullptr;
T *batch_variance_data = nullptr;
if (!is_test) {
if (!global_stats) {
batch_mean_data = batch_mean->mutable_data<T>(ctx.GetPlace());
batch_variance_data = batch_variance->mutable_data<T>(ctx.GetPlace());
}
auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto propagation = global_stats == true
? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto src_tz = paddle::framework::vectorize2int(x->dims());
auto scale_tz = paddle::framework::vectorize2int(scale->dims());
......@@ -199,7 +202,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
shift->data<T>() + ic, &scaleshift_data);
unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats;
if (global_stats) flags |= mkldnn::use_global_stats;
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor
......@@ -208,7 +211,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// keys for backward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, is_test, input_format,
src_tz, epsilon, flags, global_stats, input_format,
ctx.op().Output("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
......@@ -239,7 +242,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data);
std::shared_ptr<batch_norm_fwd> batch_norm_p;
if (is_test) {
if (global_stats) {
// create mkldnn memory for stats (as input)
std::shared_ptr<memory> mean_memory =
handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data));
......@@ -269,7 +272,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*batch_norm_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
if (!is_test) {
if (!global_stats) {
// mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib
EigenVectorArrayMap<T> batch_mean_e(batch_mean_data, ic);
......
......@@ -159,6 +159,14 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("fuse_with_relu",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("use_global_stats",
"(bool, default false) Whether to use global mean and "
"variance. In inference or test mode, set use_global_stats "
"to true or is_test true. the behavior is equivalent. "
"In train mode, when setting use_global_stats True, the "
"global mean and variance are also used during train time, "
"the BN acts as scaling and shiffting.")
.SetDefault(false);
AddComment(R"DOC(
Batch Normalization.
......@@ -190,6 +198,10 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
bool global_stats = is_test || use_global_stats;
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
......@@ -217,7 +229,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
if (!is_test) {
if (!global_stats) {
// saved_xx is use just in this batch of data
EigenVectorArrayMap<T> saved_mean_e(
saved_mean->mutable_data<T>(ctx.GetPlace()), C);
......@@ -234,7 +246,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
if ((N * sample_size) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x.";
framework::TensorCopySync(*x, ctx.GetPlace(), y);
framework::TensorCopy(*x, ctx.GetPlace(), y);
return;
}
......@@ -277,7 +289,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
// use SavedMean and SavedVariance to do normalize
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
if (is_test) {
if (global_stats) {
ConstEigenVectorArrayMap<T> var_arr(
ctx.Input<Tensor>("Variance")->data<T>(), C);
inv_std = (var_arr + epsilon).sqrt().inverse();
......@@ -289,8 +301,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
inv_std = saved_inv_std;
}
ConstEigenVectorArrayMap<T> mean_arr(
is_test ? ctx.Input<Tensor>("Mean")->data<T>()
: ctx.Output<Tensor>("SavedMean")->data<T>(),
global_stats ? ctx.Input<Tensor>("Mean")->data<T>()
: ctx.Output<Tensor>("SavedMean")->data<T>(),
C);
// ((x - est_mean) * (inv_var) * scale + bias
......@@ -336,15 +348,27 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");
PADDLE_ENFORCE(ctx->HasInput("SavedMean"), "");
PADDLE_ENFORCE(ctx->HasInput("SavedVariance"), "");
PADDLE_ENFORCE(ctx->HasInput("Scale"), "Input(scale) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("SavedMean"),
"Input(SavedMean) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("SavedVariance"),
"Input(SavedVariance) should not be null");
// check output
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Scale")), "");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), "");
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Output(Scale@GRAD) and Output(Bias@GRAD) should not be "
"null at same time");
}
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"),
"Using global stats during training is not supported "
"in gradient op kernel of batch_norm_mkldnn_op now.");
}
const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout(
......@@ -354,8 +378,10 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
: x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
}
protected:
......@@ -405,6 +431,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
// SavedVariance have been reverted in forward operator
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
......@@ -419,38 +447,60 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
: x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C;
ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
ConstEigenVectorArrayMap<T> mean_arr(saved_mean->data<T>(), C);
ConstEigenVectorArrayMap<T> inv_var_arr(saved_inv_variance->data<T>(), C);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
const T *mean_data = saved_mean->data<T>();
const T *inv_var_data = saved_inv_variance->data<T>();
Tensor inv_var_tensor;
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
mean_data = running_mean->data<T>();
T *running_inv_var_data = inv_var_tensor.mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
inv_var_tmp = (var_arr + epsilon).sqrt().inverse().eval();
inv_var_data = running_inv_var_data;
}
ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);
T *d_bias_data = nullptr;
T *d_scale_data = nullptr;
if (d_scale && d_bias) {
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
d_bias_data = d_bias->mutable_data<T>(ctx.GetPlace());
d_scale_data = d_scale->mutable_data<T>(ctx.GetPlace());
}
// d_bias = np.sum(d_y, axis=0)
// d_scale = np.sum((X - mean) / inv_std * dy, axis=0)
// d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
// - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
EigenVectorArrayMap<T> d_bias_arr(d_bias_data, C);
EigenVectorArrayMap<T> d_scale_arr(d_scale_data, C);
EigenVectorArrayMap<T> d_bias_arr(d_bias->mutable_data<T>(ctx.GetPlace()),
C);
EigenVectorArrayMap<T> d_scale_arr(d_scale->mutable_data<T>(ctx.GetPlace()),
C);
d_bias_arr.setZero();
d_scale_arr.setZero();
if (d_scale && d_bias) {
d_bias_arr.setZero();
d_scale_arr.setZero();
}
if ((N * sample_size) == 1) {
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
if ((N * sample_size) == 1 && !use_global_stats) {
framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
return;
}
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size);
int scale_coefff = use_global_stats ? 1 : N * sample_size;
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / scale_coefff;
switch (data_layout) {
case DataLayout::kNCHW: {
......@@ -460,19 +510,29 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
sample_size, N * C);
d_x_arr.setZero();
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_bias_arr(c) += d_y_arr.col(nc).sum();
d_scale_arr(c) +=
((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * d_y_arr.col(nc))
.sum();
if (d_scale && d_bias) {
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_bias_arr(c) += d_y_arr.col(nc).sum();
d_scale_arr(c) += ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) *
d_y_arr.col(nc))
.sum();
}
}
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_x_arr.col(nc) +=
scale_inv_var_nhw(c) *
(d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) -
(x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) * inv_var_arr(c));
if (!use_global_stats) {
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_x_arr.col(nc) +=
scale_inv_var_nhw(c) *
(d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) -
(x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) *
inv_var_arr(c));
}
} else {
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_x_arr.col(nc) += scale_inv_var_nhw(c) * d_y_arr.col(nc);
}
}
break;
}
......@@ -488,15 +548,27 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const auto d_y_mul_x_minus_mean_row_sum =
(d_y_arr * x_minus_mean).rowwise().sum();
const auto inv_var_sqr = inv_var_arr * inv_var_arr;
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_bias_arr += d_y_arr.col(nhw);
d_scale_arr +=
(x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw);
d_x_arr.col(nhw) +=
scale_inv_var_nhw *
(d_y_arr.col(nhw) * N * sample_size - d_y_row_sum -
x_minus_mean.col(nhw) * inv_var_sqr *
d_y_mul_x_minus_mean_row_sum);
if (d_scale && d_bias) {
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_bias_arr += d_y_arr.col(nhw);
d_scale_arr +=
(x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw);
}
}
if (!use_global_stats) {
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_x_arr.col(nhw) +=
scale_inv_var_nhw *
(d_y_arr.col(nhw) * N * sample_size - d_y_row_sum -
x_minus_mean.col(nhw) * inv_var_sqr *
d_y_mul_x_minus_mean_row_sum);
}
} else {
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_x_arr.col(nhw) += scale_inv_var_nhw * d_y_arr.col(nhw);
}
}
break;
}
......@@ -522,6 +594,10 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput("SavedMean", Output("SavedMean"));
op->SetInput("SavedVariance", Output("SavedVariance"));
// used when setting use_global_stats True during training
op->SetInput("Mean", Output("MeanOut"));
op->SetInput("Variance", Output("VarianceOut"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
......
......@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
......@@ -59,6 +63,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
......@@ -121,7 +126,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
auto handle = dev_ctx.cudnn_handle();
// Now, depending on whether we are running test or not, we have two paths.
if (is_test) {
if (is_test || use_global_stats) {
// only when test we use input to do computation.
const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance");
......@@ -163,7 +168,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
if ((N * H * W * D) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x.";
framework::TensorCopySync(*x, ctx.GetPlace(), y);
framework::TensorCopy(*x, ctx.GetPlace(), y);
} else {
double this_factor = 1. - momentum;
......@@ -191,6 +196,58 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
}
};
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *variance,
const double epsilon, const int C,
const int HxW, const int num, T *dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> inv_var = 1.0 / sqrt(variance[c] + epsilon);
dx[i] = static_cast<T>(static_cast<BatchNormParamType<T>>(dy[i]) *
scale[c] * inv_var);
}
}
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
const T *dy, const T *x, const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance, const double epsilon, const int N,
const int C, const int HxW, BatchNormParamType<T> *dscale,
BatchNormParamType<T> *dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> inv_var_i = 1.0 / sqrt(variance[i] + epsilon);
BatchNormParamType<T> mean_i = mean[i];
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
ds_sum += static_cast<BatchNormParamType<T>>(dy[index]) *
(static_cast<BatchNormParamType<T>>(x[index]) - mean_i);
db_sum += static_cast<BatchNormParamType<T>>(dy[index]);
}
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
if (threadIdx.x == 0) {
dscale[i] = ds_sum * inv_var_i;
dbias[i] = db_sum;
}
__syncthreads();
}
}
template <typename T>
class BatchNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
......@@ -200,6 +257,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
"It must use CUDAPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const auto *x = ctx.Input<Tensor>("X");
......@@ -219,42 +278,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if ((N * H * W * D) == 1) {
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
return;
if (d_scale && d_bias) {
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
}
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
std::vector<int> dims;
std::vector<int> strides;
if (data_layout == DataLayout::kNCHW) {
......@@ -264,34 +294,114 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
dims = {N, C, H, W, D};
strides = {H * W * C * D, 1, W * D * C, D * C, C};
}
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data =
saved_mean->template data<BatchNormParamType<T>>();
const void *saved_var_data =
saved_var->template data<BatchNormParamType<T>>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (!use_global_stats) {
if ((N * H * W * D) == 1) {
framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
return;
}
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data =
saved_mean->template data<BatchNormParamType<T>>();
const void *saved_var_data =
saved_var->template data<BatchNormParamType<T>>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));
// clean when exit.
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
} else {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_var = ctx.Input<Tensor>("Variance");
const auto *running_mean_data =
running_mean->template data<BatchNormParamType<T>>();
const auto *running_var_data =
running_var->template data<BatchNormParamType<T>>();
const int num = x->numel();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid1 = (num + block - 1) / block;
int grid2 = std::min(C, max_blocks);
if (data_layout == framework::DataLayout::kNCHW) {
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNCHW><<<
grid1, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
running_var_data, epsilon, C, H * W, num, d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, framework::DataLayout::kNCHW><<<
grid2, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
epsilon, C, H * W, num, d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} else {
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNHWC><<<
grid1, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
running_var_data, epsilon, C, H * W, num, d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, framework::DataLayout::kNCHW><<<
grid2, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
epsilon, C, H * W, num, d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
}
}
}
};
......
......@@ -2300,7 +2300,8 @@ def batch_norm(input,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=False,
fuse_with_relu=False):
fuse_with_relu=False,
use_global_stats=False):
"""
**Batch Normalization Layer**
......@@ -2327,6 +2328,19 @@ def batch_norm(input,
\\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift
When use_global_stats = True, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch.
They are global (or running) statistics. (It usually got from the
pre-trained model.)
The training and testing (or inference) have the same behavior:
.. math::
\\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
\\sigma_{\\beta}^{2} + \\epsilon}} \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta
Args:
input(variable): The input variable which is a LoDTensor.
act(string, Default None): Activation type, linear|relu|prelu|...
......@@ -2349,6 +2363,11 @@ def batch_norm(input,
moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance.
do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not.
fuse_with_relu (bool): if True, this OP performs relu after batch norm.
use_global_stats(bool, Default False): Whether to use global mean and
variance. In inference or test mode, set use_global_stats to true
or is_test to true, and the behavior is equivalent.
In train mode, when setting use_global_stats True, the global mean
and variance are also used during train period.
Returns:
Variable: A tensor variable which is the result after applying batch normalization on the input.
......@@ -2381,9 +2400,15 @@ def batch_norm(input,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
# setting stop_gradient=True to reduce computation
if use_global_stats and helper.param_attr.learning_rate == 0.:
scale.stop_gradient = True
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
# setting stop_gradient=True to reduce computation
if use_global_stats and helper.bias_attr.learning_rate == 0.:
scale.stop_gradient = True
mean = helper.create_parameter(
attr=ParamAttr(
......@@ -2439,7 +2464,8 @@ def batch_norm(input,
"epsilon": epsilon,
"is_test": is_test,
"use_mkldnn": False,
"fuse_with_relu": fuse_with_relu
"fuse_with_relu": fuse_with_relu,
"use_global_stats": use_global_stats
})
return helper.append_activation(batch_norm_out)
......
......@@ -54,6 +54,19 @@ def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
return y
def _cal_mean_variance(x, epsilon, data_format):
assert data_format in ['NCHW', 'NHWC']
x_square = x * x
axis = (0, 2, 3) if data_format == 'NCHW' else (0, 1, 2)
C = x.shape[1] if data_format == 'NCHW' else x.shape[-1]
x_square_sum = np.sum(x_square, axis)
x_sum = np.sum(x, axis=axis)
element_count = np.size(x) / C
mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean
return mean, var
def _reference_training(x, scale, offset, epsilon, data_format):
x_shape = x.shape
......@@ -294,7 +307,18 @@ class TestBatchNormOpTraining(unittest.TestCase):
self.use_mkldnn = False
self.fuse_with_relu = False
self.data_formats = ["NCHW", "NHWC"]
self.momentum = 0.9
self.epsilon = 0.00001
self.init_kernel_type()
self.init_test_case()
def init_test_case(self):
self.use_global_stats = False
self.no_grad_set = set()
self.fetch_list = [
'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD',
'scale@GRAD', 'bias@GRAD'
]
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
np.allclose(np.array(tensor), np_array, atol=atol)
......@@ -313,11 +337,22 @@ class TestBatchNormOpTraining(unittest.TestCase):
return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
def set_mean_variance(self, scale_shape, x, data_layout):
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
# computing global mean/variance for one step
if self.use_global_stats:
mom = self.momentum
x_mean, x_var = _cal_mean_variance(x, self.epsilon, data_layout)
mean = x_mean * (1. - mom) + mom * mean
variance = x_var * (1. - mom) + mom * variance
return mean, variance
def test_forward_backward(self):
def test_with_place(place, data_layout, shape):
# attr
epsilon = 0.00001
momentum = 0.9
epsilon = self.epsilon
momentum = self.momentum
if data_layout == "NCHW":
n, c, h, w = shape[0], shape[1], shape[2], shape[3]
else:
......@@ -328,9 +363,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
x = np.random.random_sample(shape).astype(np.float32)
scale = np.random.random_sample(scale_shape).astype(np.float32)
bias = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
mean, variance = self.set_mean_variance(scale_shape, x, data_layout)
y_grad = np.random.random_sample(shape).astype(np.float32)
y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
......@@ -339,6 +372,9 @@ class TestBatchNormOpTraining(unittest.TestCase):
var_dict = locals()
var_dict['y@GRAD'] = y_grad
var_dict['x@GRAD'] = x_grad
var_dict['scale@GRAD'] = scale_grad
var_dict['bias@GRAD'] = bias_grad
var_names = [
'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean',
......@@ -365,9 +401,8 @@ class TestBatchNormOpTraining(unittest.TestCase):
},
outputs={
"Y": block.var('y'),
"MeanOut": block.var('mean'), # share the same memory
"VarianceOut":
block.var('variance'), # share the same memory
"MeanOut": block.var('mean'), # share memory
"VarianceOut": block.var('variance'), # share memory
"SavedMean": block.var('saved_mean'),
"SavedVariance": block.var('saved_variance')
},
......@@ -377,13 +412,14 @@ class TestBatchNormOpTraining(unittest.TestCase):
"is_test": False,
"data_layout": data_layout,
"use_mkldnn": self.use_mkldnn,
"fuse_with_relu": self.fuse_with_relu
"fuse_with_relu": self.fuse_with_relu,
"use_global_stats": self.use_global_stats
})
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
# generate backward op_desc
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
bn_op.desc, set(), [])
bn_op.desc, self.no_grad_set, [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
......@@ -403,20 +439,10 @@ class TestBatchNormOpTraining(unittest.TestCase):
for name in
['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD']
},
fetch_list=[
'y', 'mean', 'variance', 'saved_mean', 'saved_variance',
'x@GRAD', 'scale@GRAD', 'bias@GRAD'
])
self.__assert_close(y, out[0], "y")
self.__assert_close(mean_out, out[1], "mean")
self.__assert_close(variance_out, out[2], "variance", 1e-3)
self.__assert_close(saved_mean, out[3], "saved_mean")
self.__assert_close(saved_variance, out[4], "saved_variance", 1e-3)
self.__assert_close(x_grad, out[5], "x_grad")
self.__assert_close(scale_grad, out[6], "scale_grad")
self.__assert_close(bias_grad, out[7], "bias_grad")
fetch_list=self.fetch_list)
for id, name in enumerate(self.fetch_list):
self.__assert_close(var_dict[name], out[id], name)
print("op test forward passed: ", str(place), data_layout)
places = [core.CPUPlace()]
......@@ -432,5 +458,66 @@ class TestBatchNormOpTraining(unittest.TestCase):
pass
class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
def init_test_case(self):
self.use_global_stats = True
self.no_grad_set = set()
self.fetch_list = [
'y', 'mean', 'variance', 'x@GRAD', 'scale@GRAD', 'bias@GRAD'
]
def reference_grad(self, x, y_grad, scale, mean, var, epsilon, data_format):
if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
y_grad = np.transpose(y_grad, (0, 2, 3, 1))
x_grad = scale * y_grad / np.sqrt(var + epsilon)
grad_scale = np.sum(y_grad * (x - mean) / np.sqrt(var + epsilon),
axis=(0, 1, 2))
grad_offset = np.sum(y_grad, axis=(0, 1, 2))
# transfer back to N, C, H, W
if data_format == "NCHW":
x_grad = np.transpose(x_grad, (0, 3, 1, 2))
x = np.transpose(x, (0, 3, 1, 2))
y_grad = np.transpose(y_grad, (0, 3, 1, 2))
return x_grad, grad_scale, grad_offset
def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
epsilon, momentum, shape, data_layout):
if data_layout != "NCHW" and data_layout != "NHWC":
raise ValueError("Unknown data order.")
if data_layout == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
# run normalizaton
normalized = (x - mean) / np.sqrt(variance + epsilon)
y = normalized * scale + bias
# transfer back to N, C, H, W
if data_layout == "NCHW":
x = np.transpose(x, (0, 3, 1, 2))
y = np.transpose(y, (0, 3, 1, 2))
mean_out = mean
variance_out = variance
saved_variance = 1. / np.sqrt(variance + epsilon)
# run backward
x_grad, scale_grad, bias_grad = self.reference_grad(
x, y_grad, scale, mean, variance, epsilon, data_layout)
return y, mean_out, variance_out, mean, saved_variance, x_grad, scale_grad, bias_grad
class TestBatchNormOpFreezeStatsAndScaleBiasTraining(
TestBatchNormOpFreezeStatsTraining):
def init_test_case(self):
self.use_global_stats = True
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
if __name__ == '__main__':
unittest.main()
......@@ -955,6 +955,15 @@ class TestBook(unittest.TestCase):
print(str(program))
def test_batch_norm(self):
program = Program()
with program_guard(program):
data = layers.data(
name='data', shape=[32, 128, 128], dtype="float32")
out = layers.batch_norm(data)
print(str(program))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册