Unverified commit b19a5979, authored by ceci3 and committed by GitHub

[cherry-pick 1.8]add batch norm double grad, test=release/1.8 (#27441)

* add batch norm double grad, test=release/1.8

* update, test=develop

* fix, test=develop
Parent 8e1712a7
@@ -831,6 +831,414 @@ void BatchNormGradMaker<T>::Apply(GradOpPtr<T> op) const {
  op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
template <typename T>
void BatchNormDoubleGradMaker<T>::Apply(GradOpPtr<T> op) const {
op->SetType("batch_norm_grad_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("SavedMean", this->Input("SavedMean"));
op->SetInput("SavedVariance", this->Input("SavedVariance"));
if (boost::get<bool>(this->GetAttr("use_global_stats"))) {
op->SetInput("Mean", this->Input("Mean"));
op->SetInput("Variance", this->Input("Variance"));
}
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetInput("DDScale", this->OutputGrad(framework::GradVarName("Scale")));
op->SetInput("DDBias", this->OutputGrad(framework::GradVarName("Bias")));
op->SetInput("DY", this->Input(framework::GradVarName("Y")));
op->SetAttrMap(this->Attrs());
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DScale", this->InputGrad("Scale"));
op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y")));
}
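To summarize the wiring above: the generated batch_norm_grad_grad op takes the forward inputs and saved statistics together with the gradients of the first-order gradients, and produces the second-order outputs. The Python dict below is purely illustrative (not Paddle API code); Mean and Variance are only fed when use_global_stats is true.

    # Illustration only: variables wired up by BatchNormDoubleGradMaker.
    batch_norm_grad_grad_io = {
        "inputs": [
            "X", "Scale", "SavedMean", "SavedVariance",
            "Mean", "Variance",          # only if use_global_stats == True
            "DDX", "DDScale", "DDBias",  # grads of the first-order grads
            "DY",                        # upstream gradient of Y
        ],
        "outputs": ["DX", "DScale", "DDY"],
    }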
void BatchNormDoubleGradOp::InferShape(
framework::InferShapeContext *ctx) const {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
"BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
"BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
"BatchNormDoubleGrad");
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "VarianceOut",
"BatchNormDoubleGrad");
}
OP_INOUT_CHECK(ctx->HasInput("DY"), "Input", "DY", "BatchNormDoubleGrad");
// check output
OP_INOUT_CHECK(ctx->HasOutput("DX"), "Output", "DX", "BatchNormDoubleGrad");
const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
if (ctx->HasOutput("DX")) {
ctx->SetOutputDim("DX", x_dims);
}
if (ctx->HasOutput("DScale")) {
ctx->SetOutputDim("DScale", {C});
}
if (ctx->HasOutput("DDY")) {
ctx->ShareDim("X", "DDY");
}
}
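As a concrete illustration of the shape inference above (not part of the diff): for an NCHW input of shape [2, 3, 4, 5], the shape used by the unit tests further down, C is 3 and the outputs come out as follows.

    # Illustration: shapes inferred for x_dims = [2, 3, 4, 5] in NCHW layout.
    inferred_shapes = {
        "DX": [2, 3, 4, 5],      # same dims as X
        "DScale": [3],           # one value per channel
        "DDY": [2, 3, 4, 5],     # shares X's dims
    }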
framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar("DY");
if (var == nullptr) {
PADDLE_THROW(
platform::errors::NotFound("cannot find gradient variable of Y"));
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
template <typename T>
class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *X = ctx.Input<Tensor>("X");
const auto *Scale = ctx.Input<Tensor>("Scale");
const auto *dY = ctx.Input<Tensor>("DY");
const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
const float epsilon = ctx.Attr<float>("epsilon");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const auto *ddX = ctx.Input<Tensor>("DDX");
const auto *ddScale = ctx.Input<Tensor>("DDScale");
const auto *ddBias = ctx.Input<Tensor>("DDBias");
auto *dX = ctx.Output<Tensor>("DX");
auto *dScale = ctx.Output<Tensor>("DScale");
auto *ddY = ctx.Output<Tensor>("DDY");
dX->mutable_data<T>(ctx.GetPlace());
ddY->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
const auto &x_dims = X->dims();
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int sample_size = X->numel() / C;
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
const T *mean_data = Saved_mean->data<T>();
const T *inv_var_data = Saved_variance->data<T>();
Tensor inv_var_tensor;
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
mean_data = running_mean->data<T>();
inv_var_tensor.Resize({C});
T *running_inv_var_data = inv_var_tensor.mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
inv_var_tmp = (var_arr + epsilon).sqrt().inverse();
inv_var_data = running_inv_var_data;
}
    // transpose NCHW -> NHWC for easier calculation
Tensor transformed_x(X->type());
Tensor transformed_dy(dY->type());
Tensor transformed_ddx(ddX->type());
Tensor transformed_dx(dX->type());
Tensor transformed_ddy(ddY->type());
if (data_layout == DataLayout::kNCHW && x_dims.size() > 2) {
VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
// Input Tensor
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, X,
&transformed_x);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, X, &transformed_x);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, dY,
&transformed_dy);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, dY,
&transformed_dy);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, ddX,
&transformed_ddx);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, ddX,
&transformed_ddx);
// Output Tensor
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, dX,
&transformed_dx);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, ddY,
&transformed_ddy);
} else {
transformed_x.ShareDataWith(*X);
transformed_dy.ShareDataWith(*dY);
transformed_ddx.ShareDataWith(*ddX);
transformed_dx.ShareDataWith(*dX);
transformed_ddy.ShareDataWith(*ddY);
}
ConstEigenArrayMap<T> x_arr(transformed_x.data<T>(), C, sample_size);
ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);
Tensor mean_tile;
mean_tile.Resize({C, sample_size});
mean_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> mean_tile_data(mean_tile.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
Tensor inv_var_tile;
inv_var_tile.Resize({C, sample_size});
inv_var_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> inv_var_tile_data(
inv_var_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
mean_tile_data = mean_arr.replicate(1, sample_size);
inv_var_tile_data = inv_var_arr.replicate(1, sample_size);
Tensor Scale_data;
if (!Scale) {
Scale_data.mutable_data<T>({C}, ctx.GetPlace());
set_constant(dev_ctx, &Scale_data, static_cast<T>(1));
}
ConstEigenVectorArrayMap<T> scale_arr(
Scale ? Scale->data<T>() : Scale_data.data<T>(), C);
Tensor scale_tile;
scale_tile.Resize({C, sample_size});
scale_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> scale_tile_data(scale_tile.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
scale_tile_data = scale_arr.replicate(1, sample_size);
ConstEigenArrayMap<T> dy_arr(transformed_dy.data<T>(), C, sample_size);
ConstEigenArrayMap<T> ddx_arr(transformed_ddx.data<T>(), C, sample_size);
Tensor x_sub_mean_mul_invstd;
x_sub_mean_mul_invstd.Resize({C, sample_size});
x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> x_sub_mean_mul_invstd_arr(
x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace()), C, sample_size);
x_sub_mean_mul_invstd_arr = (x_arr - mean_tile_data) * inv_var_tile_data;
if (dX) {
dX->mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> dx_arr(transformed_dx.mutable_data<T>(ctx.GetPlace()), C,
sample_size);
dx_arr.setZero();
if (use_global_stats) {
// math: dx = (ddscale * dy) * inv_var
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
dx_arr = dy_arr * ddscale_tile_data * inv_var_tile_data;
}
} else {
// math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx,
// axis=(n,h,w)) *
// np.sum(dy, axis=(n,h,w)) -
// np.sum(dy * ddx, axis=(n,h,w)) + 3 * np.mean(dy * (x -
// mean),
// axis=(n,h,w)) * inv_var.pow(2) *
// np.sum(ddx * (x - mean), axis=(n,h,w))) + inv_var.pow(3) /
// NxHxW *
// np.sum(ddx * (x - mean)) *
// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
// np.sum(dy,
// axis=(n,h,w)) * (x - mean) *
// (np.mean(ddx, axis=(n,h,w)) - ddx)) + ddr * (dy * inv_var -
// inv_var
// *
// np.mean(dy, axis=(n,h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(n,h,w)))
if (ddX) {
dx_arr +=
(x_sub_mean_mul_invstd_arr * inv_var_tile_data *
inv_var_tile_data / sample_size)
.colwise() *
(ddx_arr.rowwise().sum() * dy_arr.rowwise().sum() / sample_size -
(dy_arr * ddx_arr).rowwise().sum() +
3. * (dy_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() *
(ddx_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size);
dx_arr += (inv_var_tile_data * inv_var_tile_data).colwise() *
(ddx_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size *
(dy_arr.rowwise().sum() / sample_size - dy_arr);
dx_arr += (inv_var_tile_data * inv_var_tile_data).colwise() *
(dy_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size *
(ddx_arr.rowwise().sum() / sample_size - ddx_arr);
dx_arr = scale_tile_data * dx_arr;
}
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
dx_arr += (dy_arr * inv_var_tile_data -
(dy_arr.rowwise().sum().replicate(1, sample_size) /
sample_size) *
inv_var_tile_data -
x_sub_mean_mul_invstd_arr * inv_var_tile_data *
(dy_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size) *
ddscale_tile_data;
}
}
if (data_layout == DataLayout::kNCHW) {
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
TransToChannelFirst<paddle::platform::CPUDeviceContext, T>(
ctx, &transformed_dx, dX);
}
}
if (dScale) {
dScale->mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> dscale_arr(dScale->mutable_data<T>(ctx.GetPlace()),
C);
dscale_arr.setZero();
if (use_global_stats) {
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
if (ddX) {
dscale_arr = (ddx_arr * dy_arr * inv_var_tile_data).rowwise().sum();
}
} else {
// math: dscale = inv_var * (dy - np.mean(dy, axis=(n,h,w) - (x-mean) *
// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) *
// ddx
if (ddX) {
Tensor first_grad;
first_grad.Resize({C, sample_size});
EigenArrayMap<T> first_grad_arr(
first_grad.mutable_data<T>(ctx.GetPlace()), C, sample_size);
first_grad_arr.setZero();
first_grad_arr +=
inv_var_tile_data *
(dy_arr -
dy_arr.rowwise().sum().replicate(1, sample_size) / sample_size -
x_sub_mean_mul_invstd_arr *
(dy_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size);
dscale_arr = (first_grad_arr * ddx_arr).rowwise().sum();
}
}
}
if (ddY) {
ddY->mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> ddy_arr(transformed_ddy.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
ddy_arr.setZero();
if (use_global_stats) {
// math: ddy = r * ddx * inv_var + ddbias +
// ddscale * (x - mean) * inv_var
if (ddX) {
ddy_arr = scale_tile_data * ddx_arr * inv_var_tile_data;
}
} else {
// math: ddy = (x - mean) * inv_var * ddscale + ddbias +
// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
// np.mean(ddx * (x - mean), axis=(n,h,w)))
if (ddX) {
ddy_arr +=
scale_tile_data * inv_var_tile_data *
(ddx_arr -
ddx_arr.rowwise().sum().replicate(1, sample_size) / sample_size -
x_sub_mean_mul_invstd_arr *
(ddx_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size);
}
}
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data;
}
if (ddBias) {
ConstEigenVectorArrayMap<T> ddbias_arr(ddBias->data<T>(), C);
Tensor ddbias_tile;
ddbias_tile.Resize({C, sample_size});
EigenArrayMap<T> ddbias_tile_data(
ddbias_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddbias_tile_data = ddbias_arr.replicate(1, sample_size);
ddy_arr += ddbias_tile_data;
}
if (data_layout == DataLayout::kNCHW) {
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
TransToChannelFirst<paddle::platform::CPUDeviceContext, T>(
ctx, &transformed_ddy, ddY);
}
}
}
};
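For readers checking the Eigen expressions above, here is a minimal NumPy sketch of the training-mode branch (use_global_stats == False). It assumes all of ddX, ddScale and ddBias are provided, that every array has already been flattened channel-first to shape (C, sample_size) exactly like the Eigen maps, and that mean, inv_std, scale, ddscale and ddbias have shape (C, 1); the function name and layout convention are mine, not Paddle's.

    import numpy as np

    def batch_norm_double_grad_ref(x, scale, dy, mean, inv_std,
                                   ddx, ddscale, ddbias):
        # Mirrors the training-mode branch of the CPU kernel above.
        m = x.shape[1]                   # sample_size = N * H * W per channel
        x_hat = (x - mean) * inv_std     # x_sub_mean_mul_invstd_arr

        def rsum(a):                     # rowwise().sum(): per-channel sum
            return a.sum(axis=1, keepdims=True)

        dx = (x_hat * inv_std * inv_std / m) * (
            rsum(ddx) * rsum(dy) / m - rsum(dy * ddx) +
            3.0 * rsum(dy * x_hat) * rsum(ddx * x_hat) / m)
        dx += inv_std * inv_std * rsum(ddx * x_hat) / m * (rsum(dy) / m - dy)
        dx += inv_std * inv_std * rsum(dy * x_hat) / m * (rsum(ddx) / m - ddx)
        dx = scale * dx
        dx += (dy * inv_std - rsum(dy) / m * inv_std -
               x_hat * inv_std * rsum(dy * x_hat) / m) * ddscale

        first_grad = inv_std * (dy - rsum(dy) / m -
                                x_hat * rsum(dy * x_hat) / m)
        dscale = rsum(first_grad * ddx)

        ddy = scale * inv_std * (ddx - rsum(ddx) / m -
                                 x_hat * rsum(ddx * x_hat) / m)
        ddy += x_hat * ddscale + ddbias
        return dx, dscale, ddy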
DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"});
}  // namespace operators
}  // namespace paddle
@@ -839,7 +1247,11 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
                  ops::BatchNormOpInferVarType,
                  ops::BatchNormGradMaker<paddle::framework::OpDesc>,
                  ops::BatchNormGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp,
ops::BatchNormDoubleGradMaker<paddle::framework::OpDesc>,
ops::BatchNormDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_norm_grad_grad, ops::BatchNormDoubleGradOp,
ops::BatchNormDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
    batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
@@ -848,3 +1260,7 @@ REGISTER_OP_CPU_KERNEL(
    batch_norm_grad,
    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
batch_norm_grad_grad,
ops::BatchNormDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.cu.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
@@ -840,6 +841,45 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
  }
};
template <typename T>
class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *X = ctx.Input<Tensor>("X");
const auto *Scale = ctx.Input<Tensor>("Scale");
const auto *dY = ctx.Input<Tensor>("DY");
const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const auto *ddX = ctx.Input<Tensor>("DDX");
const auto *ddScale = ctx.Input<Tensor>("DDScale");
const auto *ddBias = ctx.Input<Tensor>("DDBias");
auto *dX = ctx.Output<Tensor>("DX");
auto *dScale = ctx.Output<Tensor>("DScale");
auto *ddY = ctx.Output<Tensor>("DDY");
NormDoubleGradFunctor<platform::CUDADeviceContext, T>(
ctx, data_layout, X, Scale, dY, Saved_mean, Saved_variance, epsilon,
use_global_stats, ddX, ddScale, ddBias, dX, dScale, ddY);
}
};
}  // namespace operators
}  // namespace paddle
@@ -853,3 +893,7 @@ REGISTER_OP_CUDA_KERNEL(
    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
batch_norm_grad_grad,
ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);
@@ -103,6 +103,42 @@ inline void TransToChannelFirst(const framework::ExecutionContext& context,
  }
}
template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[4];
in_dims_vec[4] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
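The branches above only differ in how many trailing spatial dims are carried along; in every case they move axis 1 (the channel axis) to the end. A plain-Python sketch of that shape logic (the helper name is mine):

    def channel_last_dims(dims):
        # [N, C, ...spatial...] -> [N, ...spatial..., C]; e.g.
        # [2, 3, 4, 5] -> [2, 4, 5, 3] and [2, 2, 3, 4, 5] -> [2, 3, 4, 5, 2].
        dims = list(dims)
        return dims[:1] + dims[2:] + dims[1:2]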
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
                               const Tensor* input, Tensor* transformed_input) {
@@ -154,6 +190,16 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
      const framework::OpKernelType& expected_kernel_type) const override;
};
class BatchNormDoubleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
@@ -168,6 +214,15 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
  void Apply(GradOpPtr<T> op) const override;
};
template <typename T>
class BatchNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override;
};
class BatchNormOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
@@ -190,5 +245,11 @@ class BatchNormGradKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
}  // namespace operators
}  // namespace paddle
@@ -475,11 +475,12 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
    //           (np.mean(dy, axis=(h,w)) - dy) + inv_var.pow(3) / HxW *
    //           np.sum(dy,
    //           axis=(h,w)) * (x - mean) *
    //           (np.mean(ddx, axis=(h,w)) - ddx)) + ddr * (dy * inv_var -
    //           inv_var
    //           *
    //           np.mean(dy, axis=(h,w)) -
    //           inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
    //           axis=(h,w)))
    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> set_constant;
@@ -553,9 +554,13 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
          first_grad_arr +=
              inv_var_tile_data *
              (dy_arr -
               dy_arr.colwise().sum().replicate(sample_size, 1) / sample_size -
               x_sub_mean_mul_invstd_arr *
                   (dy_arr * x_sub_mean_mul_invstd_arr)
                           .colwise()
                           .sum()
                           .replicate(sample_size, 1) /
                       sample_size);
          first_grad_arr = first_grad_arr * ddx_arr;
          for (int nc = 0; nc < NxC; ++nc) {
...
@@ -467,8 +467,8 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
    if (ddscale != nullptr) {
      for (int i = beg_idx; i < end_idx; i += BlockDim) {
        dx[i] += (dy[i] * var_val - dy_sum_val / sample_size * var_val -
                  (x[i] - mean_val) * var_val * var_val *
                      dy_mul_x_sub_mean_sum_val * var_val / sample_size) *
                 ddscale[c];
      }
    }
...
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
// math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx,
// axis=(n,h,w)) *
// np.sum(dy, axis=(n,h,w)) -
// np.sum(dy * ddx, axis=(n,h,w)) + 3 * np.mean(dy * (x -
// mean),
// axis=(n,h,w)) * inv_var.pow(2) *
// np.sum(ddx * (x - mean), axis=(n,h,w))) + inv_var.pow(3) /
// NxHxW *
// np.sum(ddx * (x - mean)) *
// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
// np.sum(dy,
// axis=(n,h,w)) * (x - mean) *
// (np.mean(ddx, axis=(n,h,w)) - ddx)) + ddr * (dy * inv_var -
// inv_var
// *
// np.mean(dy, axis=(n,h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDX(const T *x, const T *mean,
const T *variance, const T *ddx,
const T *dy, const T *scale,
const T *ddscale, const int N, const int C,
const int sample_size, const double epsilon,
T *dx) {
const int outer_size = C;
const int inner_size = N * sample_size;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage dy_storage;
__shared__ typename BlockReduce::TempStorage ddx_storage;
__shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
__shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
__shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
__shared__ T dy_sum_val;
__shared__ T ddx_sum_val;
__shared__ T dy_mul_ddx_sum_val;
__shared__ T dy_mul_x_sub_mean_sum_val;
__shared__ T ddx_mul_x_sub_mean_sum_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T mean_val = mean[i];
T var_val = variance[i];
T dy_sum = 0;
T ddx_sum = 0;
T dy_mul_ddx_sum = 0;
T dy_mul_x_sub_mean_sum = 0;
T ddx_mul_x_sub_mean_sum = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
T ddx_i = ddx[index];
T dy_i = dy[index];
T tmp = x[index] - mean_val;
dy_sum += dy_i;
ddx_sum += ddx_i;
dy_mul_ddx_sum += (ddx_i * dy_i);
dy_mul_x_sub_mean_sum += (dy_i * tmp);
ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
}
dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
dy_mul_ddx_sum =
BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
.Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
.Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());
if (threadIdx.x == 0) {
dy_sum_val = dy_sum;
ddx_sum_val = ddx_sum;
dy_mul_ddx_sum_val = dy_mul_ddx_sum;
dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
}
__syncthreads();
if (ddx != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
dx[index] +=
((x[index] - mean_val) * var_val * var_val * var_val / inner_size *
(ddx_sum_val * dy_sum_val / inner_size - dy_mul_ddx_sum_val +
3. * dy_mul_x_sub_mean_sum_val * var_val *
ddx_mul_x_sub_mean_sum_val * var_val / inner_size) +
ddx_mul_x_sub_mean_sum_val * var_val / inner_size * var_val *
var_val * (dy_sum_val / inner_size - dy[index]) +
dy_mul_x_sub_mean_sum_val * var_val / inner_size * var_val *
var_val * (ddx_sum_val / inner_size - ddx[index])) *
scale[i];
}
}
__syncthreads();
if (ddscale != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
dx[index] += (dy[index] * var_val - dy_sum_val / inner_size * var_val -
(x[index] - mean_val) * var_val * var_val *
dy_mul_x_sub_mean_sum_val * var_val / inner_size) *
ddscale[i];
}
}
}
}
// math: ddy = (x - mean) * inv_var * ddscale + ddbias +
// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
// np.mean(ddx * (x - mean), axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDDY(const T *x, const T *mean,
const T *variance, const T *ddscale,
const T *ddbias, const T *ddx,
const T *scale, const int N, const int C,
const int sample_size,
const double epsilon, T *ddy) {
const int outer_size = C;
const int inner_size = N * sample_size;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ddx_storage;
__shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
__shared__ T ddx_sum_val;
__shared__ T ddx_mul_x_sub_mean_sum_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T mean_val = mean[i];
T var_val = variance[i];
T ddx_sum = 0;
T ddx_mul_x_sub_mean_sum = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
T ddx_i = ddx[index];
ddx_sum += ddx_i;
ddx_mul_x_sub_mean_sum += (ddx_i * (x[index] - mean_val));
}
ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
.Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());
if (threadIdx.x == 0) {
ddx_sum_val = ddx_sum;
ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
}
__syncthreads();
if (ddx != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
ddy[index] += scale[i] * var_val *
(ddx[index] - ddx_sum_val / inner_size -
(x[index] - mean_val) * var_val *
ddx_mul_x_sub_mean_sum_val * var_val / inner_size);
}
}
__syncthreads();
if (ddscale != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
ddy[index] += (x[index] - mean_val) * var_val * ddscale[i];
}
}
__syncthreads();
if (ddbias != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
ddy[index] += ddbias[i];
}
}
}
}
// math: dscale = inv_var * (dy - np.mean(dy, axis=(n,h,w) - (x-mean) *
// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) *
// ddx
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDScale(const T *x, const T *mean,
const T *variance, const T *ddx,
const T *dy, const int N, const int C,
const int sample_size,
const double epsilon, T *dscale) {
const int outer_size = C;
const int inner_size = N * sample_size;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage dy_storage;
__shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
__shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
__shared__ T dy_sum_val;
__shared__ T dy_mul_x_sub_mean_sum_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T dy_sum = 0;
T dy_mul_x_sub_mean_sum = 0;
T mean_val = mean[i];
T var_val = variance[i];
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
T dy_i = dy[index];
dy_sum += dy_i;
dy_mul_x_sub_mean_sum += (dy_i * (x[index] - mean_val));
}
dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
.Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
if (threadIdx.x == 0) {
dy_sum_val = dy_sum;
dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
}
__syncthreads();
if (ddx != nullptr) {
T dscale_tmp = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
dscale_tmp += ddx[index] * var_val *
(dy[index] - dy_sum_val / inner_size -
dy_mul_x_sub_mean_sum_val * (x[index] - mean_val) *
var_val * var_val / inner_size);
}
dscale_tmp =
BlockReduce(dscale_tmp_storage).Reduce(dscale_tmp, cub::Sum());
if (threadIdx.x == 0) {
dscale[i] += dscale_tmp;
}
__syncthreads();
}
}
}
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDScaleWithGlobal(
const T *ddx, const T *variance, const T *dy, const double epsilon,
const int N, const int C, const int sample_size, T *dscale) {
int outer_size = C;
int inner_size = N * sample_size;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ddx_mul_dy_storage;
__shared__ T ddx_mul_dy_sum_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T inv_var_i = 1.0 / sqrt(variance[i] + epsilon);
T ddx_mul_dy_sum = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
T ddx_i = ddx[index];
T dy_i = dy[index];
ddx_mul_dy_sum += (ddx_i * dy_i);
}
ddx_mul_dy_sum =
BlockReduce(ddx_mul_dy_storage).Reduce(ddx_mul_dy_sum, cub::Sum());
if (threadIdx.x == 0) {
ddx_mul_dy_sum_val = ddx_mul_dy_sum;
}
__syncthreads();
if (ddx != nullptr) {
dscale[i] = inv_var_i * ddx_mul_dy_sum_val;
}
}
}
// math: dx = ddscale * dy * inv_var
template <typename T, framework::DataLayout layout>
__global__ void DoubleGradComputeDXWithGlobal(const T *dy, const T *ddscale,
const T *variance,
const double epsilon, const int C,
const int sample_size,
const int num, T *dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
if (ddscale != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
T inv_var = 1.0 / sqrt(variance[c] + epsilon);
dx[i] = dy[i] * ddscale[c] * inv_var;
}
}
}
// math: ddy = scale * ddx * inv_var + ddbias +
// ddscale * (x - mean) * inv_var
template <typename T, framework::DataLayout layout>
__global__ void DoubleGradComputeDDYWithGlobal(
const T *ddx, const T *scale, const T *mean, const T *variance, const T *x,
const T *ddbias, const T *ddscale, const double epsilon, const int C,
const int sample_size, const int num, T *ddy) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
if (ddx != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
T inv_var = 1.0 / sqrt(variance[c] + epsilon);
ddy[i] += ddx[i] * scale[c] * inv_var;
}
}
__syncthreads();
if (ddscale != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
T inv_var = 1.0 / sqrt(variance[c] + epsilon);
ddy[i] += (x[i] - mean[c]) * inv_var * ddscale[c];
}
}
__syncthreads();
if (ddbias != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
ddy[i] += ddbias[c];
}
}
}
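Pulling the three use_global_stats formulas above together, here is a NumPy sketch for NCHW data; it assumes ddx, ddscale and ddbias are all provided and uses the running statistics exactly as the *WithGlobal kernels do (the function name and broadcasting helper are mine):

    import numpy as np

    def double_grad_with_global_stats_ref(x, dy, ddx, scale, ddscale, ddbias,
                                          running_mean, running_var, eps):
        # NCHW layout: per-channel vectors are broadcast over axes (n, h, w).
        inv_var = 1.0 / np.sqrt(running_var + eps)          # shape (C,)
        bc = lambda a: a.reshape(1, -1, 1, 1)               # broadcast helper

        dx = dy * bc(ddscale) * bc(inv_var)
        dscale = (ddx * dy).sum(axis=(0, 2, 3)) * inv_var
        ddy = (ddx * bc(scale) * bc(inv_var) + bc(ddbias) +
               (x - bc(running_mean)) * bc(inv_var) * bc(ddscale))
        return dx, dscale, ddy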
template <typename DeviceContext, typename T>
void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
const DataLayout data_layout, const Tensor *X,
const Tensor *Scale, const Tensor *dY,
const Tensor *Saved_mean,
const Tensor *Saved_variance, const double epsilon,
const bool use_global_stats, const Tensor *ddX,
const Tensor *ddScale, const Tensor *ddBias,
Tensor *dX, Tensor *dScale, Tensor *ddY) {
const T *x_data = X->data<T>();
const T *dy_data = dY->data<T>();
const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());
const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data<T>());
const T *ddbias_data = (ddBias == nullptr ? nullptr : ddBias->data<T>());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
auto &x_dims = X->dims();
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int N = x_dims[0];
const int num = X->numel();
const int sample_size = num / N / C;
Tensor scale_tmp;
if (!Scale) {
scale_tmp.mutable_data<T>({C}, ctx.GetPlace());
set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
}
const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(C, max_blocks);
int grid1 = (num + block - 1) / block;
const T *mean_data, *variance_data;
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_var = ctx.Input<Tensor>("Variance");
const auto *running_mean_data = running_mean->template data<T>();
const auto *running_var_data = running_var->template data<T>();
mean_data = running_mean_data;
variance_data = running_var_data;
} else {
const T *smean_data = Saved_mean->data<T>();
const T *svariance_data = Saved_variance->data<T>();
mean_data = smean_data;
variance_data = svariance_data;
}
if (dX) {
T *dx_data = dX->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, dX, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDXWithGlobal<
T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
dx_data);
} else {
DoubleGradComputeDXWithGlobal<
T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
dx_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDX<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
ddscale_data, N, C, sample_size, epsilon, dx_data);
} else {
DoubleGradComputeDX<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
ddscale_data, N, C, sample_size, epsilon, dx_data);
}
}
}
if (dScale) {
T *dscale_data = dScale->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, dScale, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDScaleWithGlobal<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
ddx_data, variance_data, dy_data, epsilon, N, C, sample_size,
dscale_data);
} else {
DoubleGradComputeDScaleWithGlobal<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
ddx_data, variance_data, dy_data, epsilon, N, C, sample_size,
dscale_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDScale<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, N, C,
sample_size, epsilon, dscale_data);
} else {
DoubleGradComputeDScale<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, N, C,
sample_size, epsilon, dscale_data);
}
}
}
if (ddY) {
T *ddy_data = ddY->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, ddY, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDDYWithGlobal<
T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data,
ddscale_data, epsilon, C, sample_size, num, ddy_data);
} else {
DoubleGradComputeDDYWithGlobal<
T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data,
ddscale_data, epsilon, C, sample_size, num, ddy_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDDY<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddscale_data, ddbias_data,
ddx_data, scale_data, N, C, sample_size, epsilon, ddy_data);
} else {
DoubleGradComputeDDY<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddscale_data, ddbias_data,
ddx_data, scale_data, N, C, sample_size, epsilon, ddy_data);
}
}
}
}
} // namespace operators
} // namespace paddle
@@ -49,5 +49,103 @@ class TestInstanceNormDoubleGradCheck(unittest.TestCase):
            self.func(p)
class TestBatchNormDoubleGradCheck(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.data_layout = 'NCHW'
self.use_global_stats = False
self.shape = [2, 3, 4, 5]
@prog_scope()
def func(self, place):
prog = fluid.Program()
with fluid.program_guard(prog):
np.random.seed()
dtype = "float32"
eps = 0.005
atol = 2e-4
x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x')
z = fluid.layers.batch_norm(
input=x,
data_layout=self.data_layout,
use_global_stats=self.use_global_stats)
x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype)
gradient_checker.double_grad_check(
[x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
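Each case can also be run on its own; the snippet below is only a sketch and assumes these classes live in test_norm_nn_grad.py, the file whose TestInstanceNormDoubleGradCheck class this diff extends (adjust the module name if it differs):

    import unittest

    # Hypothetical module path; change it to wherever these tests actually live.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test_norm_nn_grad.TestBatchNormDoubleGradCheck")
    unittest.TextTestRunner(verbosity=2).run(suite)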
class TestBatchNormDoubleGradCheckCase1(TestBatchNormDoubleGradCheck):
def init_test(self):
self.data_layout = 'NHWC'
self.use_global_stats = False
self.shape = [2, 3, 4, 5]
class TestBatchNormDoubleGradCheckCase2(TestBatchNormDoubleGradCheck):
def init_test(self):
self.data_layout = 'NCHW'
self.use_global_stats = True
self.shape = [2, 3, 4, 5]
class TestBatchNormDoubleGradCheckCase3(TestBatchNormDoubleGradCheck):
def init_test(self):
self.data_layout = 'NHWC'
self.use_global_stats = True
self.shape = [2, 3, 4, 5]
class TestBatchNormDoubleGradCheckCase4(TestBatchNormDoubleGradCheck):
def init_test(self):
self.data_layout = 'NCHW'
self.use_global_stats = False
self.shape = [2, 2, 3, 4, 5]
class TestBatchNormDoubleGradCheckCase5(TestBatchNormDoubleGradCheck):
@prog_scope()
def func(self, place):
prog = fluid.Program()
with fluid.program_guard(prog):
np.random.seed()
dtype = "float32"
eps = 0.005
atol = 2e-4
chn = self.shape[1] if self.data_layout == 'NCHW' else self.shape[
-1]
x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x')
z = fluid.layers.batch_norm(
input=x,
data_layout=self.data_layout,
use_global_stats=self.use_global_stats)
x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype)
w, b = prog.global_block().all_parameters()[1:3]
w_arr = np.ones(chn).astype(dtype)
b_arr = np.zeros(chn).astype(dtype)
gradient_checker.double_grad_check(
[x, w, b],
z,
x_init=[x_arr, w_arr, b_arr],
atol=atol,
place=place,
eps=eps)
class TestBatchNormDoubleGradCheckCase6(TestBatchNormDoubleGradCheckCase5):
def init_test(self):
self.data_layout = 'NCHW'
self.use_global_stats = True
self.shape = [2, 3, 4, 5]
if __name__ == "__main__":
    unittest.main()