Unverified · Commit b19a5979 authored by ceci3, committed by GitHub

[cherry-pick 1.8]add batch norm double grad, test=release/1.8 (#27441)

* add batch norm double grad, test=release/1.8

* update, test=develop

* fix, test=develop
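For orientation: the new `batch_norm_grad_grad` operator consumes the forward inputs `X` and `Scale`, the saved batch statistics `SavedMean`/`SavedVariance` (plus `Mean`/`Variance` when `use_global_stats` is set), the first-order gradient `DY`, and the second-order seeds `DDX`/`DDScale`/`DDBias`, and produces `DX`, `DScale`, and `DDY`. Below is a minimal NumPy sketch of the forward mapping whose backward pass is being differentiated a second time; the helper name and the NCHW reduction axes are illustrative, and `SavedVariance` holding the inverse standard deviation matches how the CPU kernel below reads it.

import numpy as np

# Hypothetical reference of batch norm's forward pass (training mode, NCHW);
# not Paddle code, only the relation the double-grad kernels differentiate.
def batch_norm_forward(x, scale, bias, eps=1e-5):
    axis = (0, 2, 3)                        # reduce over N, H, W
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    inv_std = 1.0 / np.sqrt(var + eps)      # what "SavedVariance" stores
    x_hat = (x - mean) * inv_std
    return scale.reshape(1, -1, 1, 1) * x_hat + bias.reshape(1, -1, 1, 1)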
Parent 8e1712a7
@@ -831,6 +831,414 @@ void BatchNormGradMaker<T>::Apply(GradOpPtr<T> op) const {
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
template <typename T>
void BatchNormDoubleGradMaker<T>::Apply(GradOpPtr<T> op) const {
op->SetType("batch_norm_grad_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("SavedMean", this->Input("SavedMean"));
op->SetInput("SavedVariance", this->Input("SavedVariance"));
if (boost::get<bool>(this->GetAttr("use_global_stats"))) {
op->SetInput("Mean", this->Input("Mean"));
op->SetInput("Variance", this->Input("Variance"));
}
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetInput("DDScale", this->OutputGrad(framework::GradVarName("Scale")));
op->SetInput("DDBias", this->OutputGrad(framework::GradVarName("Bias")));
op->SetInput("DY", this->Input(framework::GradVarName("Y")));
op->SetAttrMap(this->Attrs());
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DScale", this->InputGrad("Scale"));
op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y")));
}
void BatchNormDoubleGradOp::InferShape(
framework::InferShapeContext *ctx) const {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
"BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
"BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
"BatchNormDoubleGrad");
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "VarianceOut",
"BatchNormDoubleGrad");
}
OP_INOUT_CHECK(ctx->HasInput("DY"), "Input", "DY", "BatchNormDoubleGrad");
// check output
OP_INOUT_CHECK(ctx->HasOutput("DX"), "Output", "DX", "BatchNormDoubleGrad");
const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
if (ctx->HasOutput("DX")) {
ctx->SetOutputDim("DX", x_dims);
}
if (ctx->HasOutput("DScale")) {
ctx->SetOutputDim("DScale", {C});
}
if (ctx->HasOutput("DDY")) {
ctx->ShareDim("X", "DDY");
}
}
framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar("DY");
if (var == nullptr) {
PADDLE_THROW(
platform::errors::NotFound("cannot find gradient variable of Y"));
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
template <typename T>
class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *X = ctx.Input<Tensor>("X");
const auto *Scale = ctx.Input<Tensor>("Scale");
const auto *dY = ctx.Input<Tensor>("DY");
const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
const float epsilon = ctx.Attr<float>("epsilon");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const auto *ddX = ctx.Input<Tensor>("DDX");
const auto *ddScale = ctx.Input<Tensor>("DDScale");
const auto *ddBias = ctx.Input<Tensor>("DDBias");
auto *dX = ctx.Output<Tensor>("DX");
auto *dScale = ctx.Output<Tensor>("DScale");
auto *ddY = ctx.Output<Tensor>("DDY");
dX->mutable_data<T>(ctx.GetPlace());
ddY->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
const auto &x_dims = X->dims();
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int sample_size = X->numel() / C;
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
const T *mean_data = Saved_mean->data<T>();
const T *inv_var_data = Saved_variance->data<T>();
Tensor inv_var_tensor;
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
mean_data = running_mean->data<T>();
inv_var_tensor.Resize({C});
T *running_inv_var_data = inv_var_tensor.mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
inv_var_tmp = (var_arr + epsilon).sqrt().inverse();
inv_var_data = running_inv_var_data;
}
// transpose NCHW -> NHWC for easier computation
Tensor transformed_x(X->type());
Tensor transformed_dy(dY->type());
Tensor transformed_ddx(ddX->type());
Tensor transformed_dx(dX->type());
Tensor transformed_ddy(ddY->type());
if (data_layout == DataLayout::kNCHW && x_dims.size() > 2) {
VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
// Input Tensor
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, X,
&transformed_x);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, X, &transformed_x);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, dY,
&transformed_dy);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, dY,
&transformed_dy);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, ddX,
&transformed_ddx);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, ddX,
&transformed_ddx);
// Output Tensor
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, dX,
&transformed_dx);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, ddY,
&transformed_ddy);
} else {
transformed_x.ShareDataWith(*X);
transformed_dy.ShareDataWith(*dY);
transformed_ddx.ShareDataWith(*ddX);
transformed_dx.ShareDataWith(*dX);
transformed_ddy.ShareDataWith(*ddY);
}
ConstEigenArrayMap<T> x_arr(transformed_x.data<T>(), C, sample_size);
ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);
Tensor mean_tile;
mean_tile.Resize({C, sample_size});
mean_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> mean_tile_data(mean_tile.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
Tensor inv_var_tile;
inv_var_tile.Resize({C, sample_size});
inv_var_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> inv_var_tile_data(
inv_var_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
mean_tile_data = mean_arr.replicate(1, sample_size);
inv_var_tile_data = inv_var_arr.replicate(1, sample_size);
Tensor Scale_data;
if (!Scale) {
Scale_data.mutable_data<T>({C}, ctx.GetPlace());
set_constant(dev_ctx, &Scale_data, static_cast<T>(1));
}
ConstEigenVectorArrayMap<T> scale_arr(
Scale ? Scale->data<T>() : Scale_data.data<T>(), C);
Tensor scale_tile;
scale_tile.Resize({C, sample_size});
scale_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> scale_tile_data(scale_tile.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
scale_tile_data = scale_arr.replicate(1, sample_size);
ConstEigenArrayMap<T> dy_arr(transformed_dy.data<T>(), C, sample_size);
ConstEigenArrayMap<T> ddx_arr(transformed_ddx.data<T>(), C, sample_size);
Tensor x_sub_mean_mul_invstd;
x_sub_mean_mul_invstd.Resize({C, sample_size});
x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> x_sub_mean_mul_invstd_arr(
x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace()), C, sample_size);
x_sub_mean_mul_invstd_arr = (x_arr - mean_tile_data) * inv_var_tile_data;
if (dX) {
dX->mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> dx_arr(transformed_dx.mutable_data<T>(ctx.GetPlace()), C,
sample_size);
dx_arr.setZero();
if (use_global_stats) {
// math: dx = (ddscale * dy) * inv_var
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
dx_arr = dy_arr * ddscale_tile_data * inv_var_tile_data;
}
} else {
// math: dx = scale * ((x - mean) * inv_var / NxHxW *
//            (np.mean(ddx, axis=(n,h,w)) * np.sum(dy, axis=(n,h,w)) -
//             np.sum(dy * ddx, axis=(n,h,w)) +
//             3 * np.mean(dy * (x - mean), axis=(n,h,w)) * inv_var.pow(2) *
//             np.sum(ddx * (x - mean), axis=(n,h,w))) +
//            inv_var.pow(3) / NxHxW * np.sum(ddx * (x - mean)) *
//            (np.mean(dy, axis=(n,h,w)) - dy) +
//            inv_var.pow(3) / NxHxW * np.sum(dy, axis=(n,h,w)) *
//            (x - mean) * (np.mean(ddx, axis=(n,h,w)) - ddx)) +
//            ddscale * (dy * inv_var - inv_var * np.mean(dy, axis=(n,h,w)) -
//            inv_var.pow(3) * (x - mean) *
//            np.mean(dy * (x - mean), axis=(n,h,w)))
if (ddX) {
dx_arr +=
(x_sub_mean_mul_invstd_arr * inv_var_tile_data *
inv_var_tile_data / sample_size)
.colwise() *
(ddx_arr.rowwise().sum() * dy_arr.rowwise().sum() / sample_size -
(dy_arr * ddx_arr).rowwise().sum() +
3. * (dy_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() *
(ddx_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size);
dx_arr += (inv_var_tile_data * inv_var_tile_data).colwise() *
(ddx_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size *
(dy_arr.rowwise().sum() / sample_size - dy_arr);
dx_arr += (inv_var_tile_data * inv_var_tile_data).colwise() *
(dy_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size *
(ddx_arr.rowwise().sum() / sample_size - ddx_arr);
dx_arr = scale_tile_data * dx_arr;
}
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
dx_arr += (dy_arr * inv_var_tile_data -
(dy_arr.rowwise().sum().replicate(1, sample_size) /
sample_size) *
inv_var_tile_data -
x_sub_mean_mul_invstd_arr * inv_var_tile_data *
(dy_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size) *
ddscale_tile_data;
}
}
if (data_layout == DataLayout::kNCHW) {
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
TransToChannelFirst<paddle::platform::CPUDeviceContext, T>(
ctx, &transformed_dx, dX);
}
}
if (dScale) {
dScale->mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> dscale_arr(dScale->mutable_data<T>(ctx.GetPlace()),
C);
dscale_arr.setZero();
if (use_global_stats) {
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
if (ddX) {
dscale_arr = (ddx_arr * dy_arr * inv_var_tile_data).rowwise().sum();
}
} else {
// math: dscale = np.sum(inv_var * (dy - np.mean(dy, axis=(n,h,w)) -
//                (x - mean) * inv_var.pow(2) *
//                np.mean(dy * (x - mean), axis=(n,h,w))) * ddx, axis=(n,h,w))
if (ddX) {
Tensor first_grad;
first_grad.Resize({C, sample_size});
EigenArrayMap<T> first_grad_arr(
first_grad.mutable_data<T>(ctx.GetPlace()), C, sample_size);
first_grad_arr.setZero();
first_grad_arr +=
inv_var_tile_data *
(dy_arr -
dy_arr.rowwise().sum().replicate(1, sample_size) / sample_size -
x_sub_mean_mul_invstd_arr *
(dy_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size);
dscale_arr = (first_grad_arr * ddx_arr).rowwise().sum();
}
}
}
if (ddY) {
ddY->mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> ddy_arr(transformed_ddy.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
ddy_arr.setZero();
if (use_global_stats) {
// math: ddy = scale * ddx * inv_var + ddbias +
//             ddscale * (x - mean) * inv_var
if (ddX) {
ddy_arr = scale_tile_data * ddx_arr * inv_var_tile_data;
}
} else {
// math: ddy = (x - mean) * inv_var * ddscale + ddbias +
//             scale * inv_var * (ddx - np.mean(ddx, axis=(n,h,w)) -
//             (x - mean) * inv_var.pow(2) *
//             np.mean(ddx * (x - mean), axis=(n,h,w)))
if (ddX) {
ddy_arr +=
scale_tile_data * inv_var_tile_data *
(ddx_arr -
ddx_arr.rowwise().sum().replicate(1, sample_size) / sample_size -
x_sub_mean_mul_invstd_arr *
(ddx_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size);
}
}
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data;
}
if (ddBias) {
ConstEigenVectorArrayMap<T> ddbias_arr(ddBias->data<T>(), C);
Tensor ddbias_tile;
ddbias_tile.Resize({C, sample_size});
EigenArrayMap<T> ddbias_tile_data(
ddbias_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddbias_tile_data = ddbias_arr.replicate(1, sample_size);
ddy_arr += ddbias_tile_data;
}
if (data_layout == DataLayout::kNCHW) {
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
TransToChannelFirst<paddle::platform::CPUDeviceContext, T>(
ctx, &transformed_ddy, ddY);
}
}
}
};
DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"});
}  // namespace operators
}  // namespace paddle
@@ -839,7 +1247,11 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
ops::BatchNormOpInferVarType,
ops::BatchNormGradMaker<paddle::framework::OpDesc>,
ops::BatchNormGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp,
ops::BatchNormDoubleGradMaker<paddle::framework::OpDesc>,
ops::BatchNormDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_norm_grad_grad, ops::BatchNormDoubleGradOp,
ops::BatchNormDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
@@ -848,3 +1260,7 @@ REGISTER_OP_CPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
batch_norm_grad_grad,
ops::BatchNormDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
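As a readability aid for the formulas in the comments above, here is a minimal NumPy sketch of the `use_global_stats` branch only: it reproduces dx = ddscale * dy * inv_var, dscale = sum(ddx * dy) * inv_var, and ddy = scale * ddx * inv_var + ddbias + ddscale * (x - mean) * inv_var for NCHW data. Function and argument names are illustrative, not part of the operator.

import numpy as np

# Hedged sketch of batch_norm_grad_grad with use_global_stats=True (NCHW);
# it mirrors the per-channel formulas in the CPU kernel's comments, not the
# Eigen implementation itself.
def bn_double_grad_global_stats(x, scale, dy, running_mean, running_var,
                                ddx, ddscale, ddbias, eps=1e-5):
    c = x.shape[1]
    per_channel = (1, c, 1, 1)
    inv_var = 1.0 / np.sqrt(running_var + eps)            # shape (C,)
    inv_var_b = inv_var.reshape(per_channel)
    mean_b = running_mean.reshape(per_channel)

    # dx = (ddscale * dy) * inv_var
    dx = ddscale.reshape(per_channel) * dy * inv_var_b
    # dscale = np.sum(ddx * dy, axis=(n, h, w)) * inv_var
    dscale = (ddx * dy).sum(axis=(0, 2, 3)) * inv_var
    # ddy = scale * ddx * inv_var + ddbias + ddscale * (x - mean) * inv_var
    ddy = (scale.reshape(per_channel) * ddx * inv_var_b
           + ddbias.reshape(per_channel)
           + ddscale.reshape(per_channel) * (x - mean_b) * inv_var_b)
    return dx, dscale, ddy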
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.cu.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -840,6 +841,45 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -840,6 +841,45 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} }
}; };
template <typename T>
class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *X = ctx.Input<Tensor>("X");
const auto *Scale = ctx.Input<Tensor>("Scale");
const auto *dY = ctx.Input<Tensor>("DY");
const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const auto *ddX = ctx.Input<Tensor>("DDX");
const auto *ddScale = ctx.Input<Tensor>("DDScale");
const auto *ddBias = ctx.Input<Tensor>("DDBias");
auto *dX = ctx.Output<Tensor>("DX");
auto *dScale = ctx.Output<Tensor>("DScale");
auto *ddY = ctx.Output<Tensor>("DDY");
NormDoubleGradFunctor<platform::CUDADeviceContext, T>(
ctx, data_layout, X, Scale, dY, Saved_mean, Saved_variance, epsilon,
use_global_stats, ddX, ddScale, ddBias, dX, dScale, ddY);
}
};
}  // namespace operators
}  // namespace paddle
@@ -853,3 +893,7 @@ REGISTER_OP_CUDA_KERNEL(
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
batch_norm_grad_grad,
ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);
@@ -103,6 +103,42 @@ inline void TransToChannelFirst(const framework::ExecutionContext& context,
}
}
template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[4];
in_dims_vec[4] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
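In NumPy terms, the helper above only recomputes the allocation shape; the actual data movement is left to TransToChannelLast, whose declaration follows. A small illustration for a 4-D NCHW tensor (values and sizes are arbitrary, not Paddle code):

import numpy as np

# Illustration only: the NCHW -> NHWC reshaping/transposition that
# ResizeToChannelLast / TransToChannelLast amount to for a 4-D tensor.
x_nchw = np.random.rand(2, 3, 4, 5)            # N, C, H, W
channel_last_dims = (2, 4, 5, 3)               # N, H, W, C (shape after Resize)
x_nhwc = x_nchw.transpose(0, 2, 3, 1)          # layout change done by Trans
assert x_nhwc.shape == channel_last_dims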
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
const Tensor* input, Tensor* transformed_input) {
@@ -154,6 +190,16 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
const framework::OpKernelType& expected_kernel_type) const override;
};
class BatchNormDoubleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
@@ -168,6 +214,15 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override;
};
template <typename T>
class BatchNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override;
};
class BatchNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
@@ -190,5 +245,11 @@ class BatchNormGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
}  // namespace operators
}  // namespace paddle
@@ -475,11 +475,12 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
//           (np.mean(dy, axis=(h,w)) - dy) + inv_var.pow(3) / HxW *
//           np.sum(dy, axis=(h,w)) * (x - mean) *
//           (np.mean(ddx, axis=(h,w)) - ddx)) + ddr * (dy * inv_var -
//           inv_var * np.mean(dy, axis=(h,w)) -
//           inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
//           axis=(h,w)))
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
@@ -553,9 +554,13 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
first_grad_arr +=
    inv_var_tile_data *
    (dy_arr -
     dy_arr.colwise().sum().replicate(sample_size, 1) / sample_size -
     x_sub_mean_mul_invstd_arr *
         (dy_arr * x_sub_mean_mul_invstd_arr)
             .colwise()
             .sum()
             .replicate(sample_size, 1) /
         sample_size);
first_grad_arr = first_grad_arr * ddx_arr;
for (int nc = 0; nc < NxC; ++nc) {
...
@@ -467,8 +467,8 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
if (ddscale != nullptr) {
for (int i = beg_idx; i < end_idx; i += BlockDim) {
dx[i] += (dy[i] * var_val - dy_sum_val / sample_size * var_val -
          (x[i] - mean_val) * var_val * var_val *
          dy_mul_x_sub_mean_sum_val * var_val / sample_size) *
         ddscale[c];
}
}
...
This diff is collapsed.
@@ -49,5 +49,103 @@ class TestInstanceNormDoubleGradCheck(unittest.TestCase):
            self.func(p)
class TestBatchNormDoubleGradCheck(unittest.TestCase):
    def setUp(self):
        self.init_test()

    def init_test(self):
        self.data_layout = 'NCHW'
        self.use_global_stats = False
        self.shape = [2, 3, 4, 5]

    @prog_scope()
    def func(self, place):
        prog = fluid.Program()
        with fluid.program_guard(prog):
            np.random.seed()
            dtype = "float32"
            eps = 0.005
            atol = 2e-4
            x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x')
            z = fluid.layers.batch_norm(
                input=x,
                data_layout=self.data_layout,
                use_global_stats=self.use_global_stats)
            x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype)
            gradient_checker.double_grad_check(
                [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestBatchNormDoubleGradCheckCase1(TestBatchNormDoubleGradCheck):
    def init_test(self):
        self.data_layout = 'NHWC'
        self.use_global_stats = False
        self.shape = [2, 3, 4, 5]


class TestBatchNormDoubleGradCheckCase2(TestBatchNormDoubleGradCheck):
    def init_test(self):
        self.data_layout = 'NCHW'
        self.use_global_stats = True
        self.shape = [2, 3, 4, 5]


class TestBatchNormDoubleGradCheckCase3(TestBatchNormDoubleGradCheck):
    def init_test(self):
        self.data_layout = 'NHWC'
        self.use_global_stats = True
        self.shape = [2, 3, 4, 5]


class TestBatchNormDoubleGradCheckCase4(TestBatchNormDoubleGradCheck):
    def init_test(self):
        self.data_layout = 'NCHW'
        self.use_global_stats = False
        self.shape = [2, 2, 3, 4, 5]


class TestBatchNormDoubleGradCheckCase5(TestBatchNormDoubleGradCheck):
    @prog_scope()
    def func(self, place):
        prog = fluid.Program()
        with fluid.program_guard(prog):
            np.random.seed()
            dtype = "float32"
            eps = 0.005
            atol = 2e-4
            chn = self.shape[1] if self.data_layout == 'NCHW' else self.shape[
                -1]
            x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x')
            z = fluid.layers.batch_norm(
                input=x,
                data_layout=self.data_layout,
                use_global_stats=self.use_global_stats)
            x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype)
            w, b = prog.global_block().all_parameters()[1:3]
            w_arr = np.ones(chn).astype(dtype)
            b_arr = np.zeros(chn).astype(dtype)
            gradient_checker.double_grad_check(
                [x, w, b],
                z,
                x_init=[x_arr, w_arr, b_arr],
                atol=atol,
                place=place,
                eps=eps)


class TestBatchNormDoubleGradCheckCase6(TestBatchNormDoubleGradCheckCase5):
    def init_test(self):
        self.data_layout = 'NCHW'
        self.use_global_stats = True
        self.shape = [2, 3, 4, 5]
if __name__ == "__main__":
    unittest.main()