未验证 提交 1d3b27ca 编写于 作者: C ceci3 提交者: GitHub

add double grad compute for batch norm (#27296)

* add double grad compute for batch norm,test=develop

* fix unittest, test=develop

* remove unuse tensor,test=develop

* add format,test=develop

* update, test=develop
上级 4bd7aa25
......@@ -831,6 +831,401 @@ void BatchNormGradMaker<T>::Apply(GradOpPtr<T> op) const {
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
template <typename T>
void BatchNormDoubleGradMaker<T>::Apply(GradOpPtr<T> op) const {
op->SetType("batch_norm_grad_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("SavedMean", this->Input("SavedMean"));
op->SetInput("SavedVariance", this->Input("SavedVariance"));
if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats"))) {
op->SetInput("Variance", this->Input("Variance"));
}
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetInput("DDScale", this->OutputGrad(framework::GradVarName("Scale")));
op->SetInput("DDBias", this->OutputGrad(framework::GradVarName("Bias")));
op->SetInput("DY", this->Input(framework::GradVarName("Y")));
op->SetAttrMap(this->Attrs());
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DScale", this->InputGrad("Scale"));
op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y")));
}
void BatchNormDoubleGradOp::InferShape(
framework::InferShapeContext *ctx) const {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
"BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
"BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
"BatchNormDoubleGrad");
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "VarianceOut",
"BatchNormDoubleGrad");
}
OP_INOUT_CHECK(ctx->HasInput("DDX"), "Input", "DDX", "BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("DY"), "Input", "DY", "BatchNormDoubleGrad");
// check output
OP_INOUT_CHECK(ctx->HasOutput("DX"), "Output", "DX", "BatchNormDoubleGrad");
const auto x_dims = ctx->GetInputDim("X");
const int C = x_dims[1];
if (ctx->HasOutput("DX")) {
ctx->SetOutputDim("DX", x_dims);
}
if (ctx->HasOutput("DScale")) {
ctx->SetOutputDim("DScale", {C});
}
if (ctx->HasOutput("DDY")) {
ctx->ShareDim("X", "DDY");
}
}
framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar("DY");
if (var == nullptr) {
PADDLE_THROW(
platform::errors::NotFound("cannot find gradient variable of Y"));
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
template <typename T>
class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *X = ctx.Input<Tensor>("X");
const auto *Scale = ctx.Input<Tensor>("Scale");
const auto *dY = ctx.Input<Tensor>("DY");
const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
const float epsilon = ctx.Attr<float>("epsilon");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const auto *ddX = ctx.Input<Tensor>("DDX");
const auto *ddScale = ctx.Input<Tensor>("DDScale");
const auto *ddBias = ctx.Input<Tensor>("DDBias");
auto *dX = ctx.Output<Tensor>("DX");
auto *dScale = ctx.Output<Tensor>("DScale");
auto *ddY = ctx.Output<Tensor>("DDY");
dX->mutable_data<T>(ctx.GetPlace());
ddY->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
const auto &x_dims = X->dims();
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int sample_size = X->numel() / C;
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
const T *mean_data = Saved_mean->data<T>();
const T *inv_var_data = Saved_variance->data<T>();
Tensor inv_var_tensor;
if (use_global_stats) {
const auto *running_variance = ctx.Input<Tensor>("Variance");
inv_var_tensor.Resize({C});
T *running_inv_var_data = inv_var_tensor.mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
inv_var_tmp = (var_arr + epsilon).sqrt().inverse();
inv_var_data = running_inv_var_data;
}
// transpose NCHW -> NHWC for easy calculate
Tensor transformed_x(X->type());
Tensor transformed_dy(dY->type());
Tensor transformed_ddx(ddX->type());
Tensor transformed_dx(dX->type());
Tensor transformed_ddy(ddY->type());
if (data_layout == DataLayout::kNCHW && x_dims.size() > 2) {
VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
// Input Tensor
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, X,
&transformed_x);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, X, &transformed_x);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, dY,
&transformed_dy);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, dY,
&transformed_dy);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, ddX,
&transformed_ddx);
TransToChannelLast<platform::CPUDeviceContext, T>(ctx, ddX,
&transformed_ddx);
// Output Tensor
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, dX,
&transformed_dx);
ResizeToChannelLast<platform::CPUDeviceContext, T>(ctx, ddY,
&transformed_ddy);
} else {
transformed_x.ShareDataWith(*X);
transformed_dy.ShareDataWith(*dY);
transformed_ddx.ShareDataWith(*ddX);
transformed_dx.ShareDataWith(*dX);
transformed_ddy.ShareDataWith(*ddY);
}
ConstEigenArrayMap<T> x_arr(transformed_x.data<T>(), C, sample_size);
ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);
Tensor mean_tile;
mean_tile.Resize({C, sample_size});
mean_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> mean_tile_data(mean_tile.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
Tensor inv_var_tile;
inv_var_tile.Resize({C, sample_size});
inv_var_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> inv_var_tile_data(
inv_var_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
mean_tile_data = mean_arr.replicate(1, sample_size);
inv_var_tile_data = inv_var_arr.replicate(1, sample_size);
Tensor Scale_data;
if (!Scale) {
Scale_data.mutable_data<T>({C}, ctx.GetPlace());
set_constant(dev_ctx, &Scale_data, static_cast<T>(1));
}
ConstEigenVectorArrayMap<T> scale_arr(
Scale ? Scale->data<T>() : Scale_data.data<T>(), C);
Tensor scale_tile;
scale_tile.Resize({C, sample_size});
scale_tile.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> scale_tile_data(scale_tile.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
scale_tile_data = scale_arr.replicate(1, sample_size);
ConstEigenArrayMap<T> dy_arr(transformed_dy.data<T>(), C, sample_size);
ConstEigenArrayMap<T> ddx_arr(transformed_ddx.data<T>(), C, sample_size);
Tensor x_sub_mean_mul_invstd;
x_sub_mean_mul_invstd.Resize({C, sample_size});
x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> x_sub_mean_mul_invstd_arr(
x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace()), C, sample_size);
x_sub_mean_mul_invstd_arr = (x_arr - mean_tile_data) * inv_var_tile_data;
if (dX) {
dX->mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> dx_arr(transformed_dx.mutable_data<T>(ctx.GetPlace()), C,
sample_size);
dx_arr.setZero();
if (use_global_stats) {
// math: dx = (ddscale * dy) * inv_var
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
dx_arr = dy_arr * ddscale_tile_data * inv_var_tile_data;
}
} else {
// math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx,
// axis=(n,h,w)) *
// np.sum(dy, axis=(n,h,w)) -
// np.sum(dy * ddx, axis=(n,h,w)) + 3 * np.mean(dy * (x -
// mean),
// axis=(n,h,w)) * inv_var.pow(2) *
// np.sum(ddx * (x - mean), axis=(n,h,w))) + inv_var.pow(3) /
// NxHxW *
// np.sum(ddx * (x - mean)) *
// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
// np.sum(dy,
// axis=(n,h,w)) * (x - mean) *
// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var -
// inv_var
// *
// np.mean(dy, axis=(n,h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(n,h,w))))
if (ddX) {
dx_arr +=
(x_sub_mean_mul_invstd_arr * inv_var_tile_data *
inv_var_tile_data / sample_size)
.colwise() *
(ddx_arr.rowwise().sum() * dy_arr.rowwise().sum() / sample_size -
(dy_arr * ddx_arr).rowwise().sum() +
3. * (dy_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() *
(ddx_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size);
dx_arr += (inv_var_tile_data * inv_var_tile_data).colwise() *
(ddx_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size *
(dy_arr.rowwise().sum() / sample_size - dy_arr);
dx_arr += (inv_var_tile_data * inv_var_tile_data).colwise() *
(dy_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() /
sample_size *
(ddx_arr.rowwise().sum() / sample_size - ddx_arr);
dx_arr = scale_tile_data * dx_arr;
}
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
dx_arr += (dy_arr * inv_var_tile_data -
(dy_arr.rowwise().sum().replicate(1, sample_size) /
sample_size) *
inv_var_tile_data -
x_sub_mean_mul_invstd_arr * inv_var_tile_data *
(dy_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size) *
ddscale_tile_data;
}
}
if (data_layout == DataLayout::kNCHW) {
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
TransToChannelFirst<paddle::platform::CPUDeviceContext, T>(
ctx, &transformed_dx, dX);
}
}
if (dScale) {
dScale->mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> dscale_arr(dScale->mutable_data<T>(ctx.GetPlace()),
C);
dscale_arr.setZero();
if (use_global_stats) {
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
if (ddX) {
dscale_arr = (ddx_arr * dy_arr * inv_var_tile_data).rowwise().sum();
}
} else {
// math: dscale = inv_var * (dy - np.mean(dy, axis=(n,h,w) - (x-mean) *
// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) *
// ddx
if (ddX) {
Tensor first_grad;
first_grad.Resize({C, sample_size});
EigenArrayMap<T> first_grad_arr(
first_grad.mutable_data<T>(ctx.GetPlace()), C, sample_size);
first_grad_arr.setZero();
first_grad_arr +=
inv_var_tile_data *
(dy_arr -
dy_arr.rowwise().sum().replicate(1, sample_size) / sample_size -
x_sub_mean_mul_invstd_arr *
(dy_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size);
dscale_arr = (first_grad_arr * ddx_arr).rowwise().sum();
}
}
}
if (ddY) {
ddY->mutable_data<T>(ctx.GetPlace());
EigenArrayMap<T> ddy_arr(transformed_ddy.mutable_data<T>(ctx.GetPlace()),
C, sample_size);
ddy_arr.setZero();
if (use_global_stats) {
// math: ddy = r * ddx * inv_var
if (ddX) {
ddy_arr = scale_tile_data * ddx_arr * inv_var_tile_data;
}
} else {
// math: ddy = (x - mean) * inv_var * ddscale + ddbias +
// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
// np.mean(ddx * (x - mean), axis=(n,h,w)))
if (ddX) {
ddy_arr +=
scale_tile_data * inv_var_tile_data *
(ddx_arr -
ddx_arr.rowwise().sum().replicate(1, sample_size) / sample_size -
x_sub_mean_mul_invstd_arr *
(ddx_arr * x_sub_mean_mul_invstd_arr)
.rowwise()
.sum()
.replicate(1, sample_size) /
sample_size);
}
if (ddScale && ddBias) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
ConstEigenVectorArrayMap<T> ddbias_arr(ddBias->data<T>(), C);
Tensor ddbias_tile;
ddbias_tile.Resize({C, sample_size});
EigenArrayMap<T> ddbias_tile_data(
ddbias_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddbias_tile_data = ddbias_arr.replicate(1, sample_size);
ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data;
ddy_arr += ddbias_tile_data;
}
}
if (data_layout == DataLayout::kNCHW) {
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
TransToChannelFirst<paddle::platform::CPUDeviceContext, T>(
ctx, &transformed_ddy, ddY);
}
}
}
};
DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"});
} // namespace operators
} // namespace paddle
......@@ -839,7 +1234,11 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
ops::BatchNormOpInferVarType,
ops::BatchNormGradMaker<paddle::framework::OpDesc>,
ops::BatchNormGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp,
ops::BatchNormDoubleGradMaker<paddle::framework::OpDesc>,
ops::BatchNormDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_norm_grad_grad, ops::BatchNormDoubleGradOp,
ops::BatchNormDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
......@@ -848,3 +1247,7 @@ REGISTER_OP_CPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
batch_norm_grad_grad,
ops::BatchNormDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.cu.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
......@@ -840,6 +841,45 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
};
template <typename T>
class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *X = ctx.Input<Tensor>("X");
const auto *Scale = ctx.Input<Tensor>("Scale");
const auto *dY = ctx.Input<Tensor>("DY");
const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const auto *ddX = ctx.Input<Tensor>("DDX");
const auto *ddScale = ctx.Input<Tensor>("DDScale");
const auto *ddBias = ctx.Input<Tensor>("DDBias");
auto *dX = ctx.Output<Tensor>("DX");
auto *dScale = ctx.Output<Tensor>("DScale");
auto *ddY = ctx.Output<Tensor>("DDY");
NormDoubleGradFunctor<platform::CUDADeviceContext, T>(
ctx, data_layout, X, Scale, dY, Saved_mean, Saved_variance, epsilon,
use_global_stats, ddX, ddScale, ddBias, dX, dScale, ddY);
}
};
} // namespace operators
} // namespace paddle
......@@ -853,3 +893,7 @@ REGISTER_OP_CUDA_KERNEL(
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
batch_norm_grad_grad,
ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);
......@@ -103,6 +103,42 @@ inline void TransToChannelFirst(const framework::ExecutionContext& context,
}
}
template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[4];
in_dims_vec[4] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
const Tensor* input, Tensor* transformed_input) {
......@@ -154,6 +190,16 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
const framework::OpKernelType& expected_kernel_type) const override;
};
class BatchNormDoubleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
......@@ -168,6 +214,15 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override;
};
template <typename T>
class BatchNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override;
};
class BatchNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
......@@ -190,5 +245,11 @@ class BatchNormGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
} // namespace paddle
......@@ -595,9 +595,13 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
first_grad_arr +=
inv_var_tile_data *
(dy_arr - dy_arr.colwise().sum() / sample_size -
(dy_arr -
dy_arr.colwise().sum().replicate(sample_size, 1) / sample_size -
x_sub_mean_mul_invstd_arr *
(dy_arr * x_sub_mean_mul_invstd_arr).colwise().sum() /
(dy_arr * x_sub_mean_mul_invstd_arr)
.colwise()
.sum()
.replicate(sample_size, 1) /
sample_size);
first_grad_arr = first_grad_arr * ddx_arr;
for (int nc = 0; nc < NxC; ++nc) {
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
// math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx,
// axis=(n,h,w)) *
// np.sum(dy, axis=(n,h,w)) -
// np.sum(dy * ddx, axis=(n,h,w)) + 3 * np.mean(dy * (x -
// mean),
// axis=(n,h,w)) * inv_var.pow(2) *
// np.sum(ddx * (x - mean), axis=(n,h,w))) + inv_var.pow(3) /
// NxHxW *
// np.sum(ddx * (x - mean)) *
// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
// np.sum(dy,
// axis=(n,h,w)) * (x - mean) *
// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var -
// inv_var
// *
// np.mean(dy, axis=(n,h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(n,h,w))))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDX(const T *x, const T *mean,
const T *variance, const T *ddx,
const T *dy, const T *scale,
const T *ddscale, const int N, const int C,
const int sample_size, const double epsilon,
T *dx) {
const int outer_size = C;
const int inner_size = N * sample_size;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage dy_storage;
__shared__ typename BlockReduce::TempStorage ddx_storage;
__shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
__shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
__shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
__shared__ T dy_sum_val;
__shared__ T ddx_sum_val;
__shared__ T dy_mul_ddx_sum_val;
__shared__ T dy_mul_x_sub_mean_sum_val;
__shared__ T ddx_mul_x_sub_mean_sum_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T mean_val = mean[i];
T var_val = variance[i];
T dy_sum = 0;
T ddx_sum = 0;
T dy_mul_ddx_sum = 0;
T dy_mul_x_sub_mean_sum = 0;
T ddx_mul_x_sub_mean_sum = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
T ddx_i = ddx[index];
T dy_i = dy[index];
T tmp = x[index] - mean_val;
dy_sum += dy_i;
ddx_sum += ddx_i;
dy_mul_ddx_sum += (ddx_i * dy_i);
dy_mul_x_sub_mean_sum += (dy_i * tmp);
ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
}
dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
dy_mul_ddx_sum =
BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
.Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
.Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());
if (threadIdx.x == 0) {
dy_sum_val = dy_sum;
ddx_sum_val = ddx_sum;
dy_mul_ddx_sum_val = dy_mul_ddx_sum;
dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
}
__syncthreads();
if (ddx != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
dx[index] +=
((x[index] - mean_val) * var_val * var_val * var_val / inner_size *
(ddx_sum_val * dy_sum_val / inner_size - dy_mul_ddx_sum_val +
3. * dy_mul_x_sub_mean_sum_val * var_val *
ddx_mul_x_sub_mean_sum_val * var_val / inner_size) +
ddx_mul_x_sub_mean_sum_val * var_val / inner_size * var_val *
var_val * (dy_sum_val / inner_size - dy[index]) +
dy_mul_x_sub_mean_sum_val * var_val / inner_size * var_val *
var_val * (ddx_sum_val / inner_size - ddx[index])) *
scale[i];
}
}
__syncthreads();
if (ddscale != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
dx[index] += (dy[index] * var_val - dy_sum_val / inner_size * var_val -
(x[index] - mean_val) * var_val *
dy_mul_x_sub_mean_sum_val * var_val / inner_size) *
ddscale[i];
}
}
}
}
// math: ddy = (x - mean) * inv_var * ddscale + ddbias +
// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
// np.mean(ddx * (x - mean), axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDDY(const T *x, const T *mean,
const T *variance, const T *ddscale,
const T *ddbias, const T *ddx,
const T *scale, const int N, const int C,
const int sample_size,
const double epsilon, T *ddy) {
const int outer_size = C;
const int inner_size = N * sample_size;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ddx_storage;
__shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
__shared__ T ddx_sum_val;
__shared__ T ddx_mul_x_sub_mean_sum_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T mean_val = mean[i];
T var_val = variance[i];
T ddx_sum = 0;
T ddx_mul_x_sub_mean_sum = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
T ddx_i = ddx[index];
ddx_sum += ddx_i;
ddx_mul_x_sub_mean_sum += (ddx_i * (x[index] - mean_val));
}
ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
.Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());
if (threadIdx.x == 0) {
ddx_sum_val = ddx_sum;
ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
}
__syncthreads();
if (ddx != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
ddy[index] += scale[i] * var_val *
(ddx[index] - ddx_sum_val / inner_size -
(x[index] - mean_val) * var_val *
ddx_mul_x_sub_mean_sum_val * var_val / inner_size);
}
}
__syncthreads();
if (ddscale != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
ddy[index] += (x[index] - mean_val) * var_val * ddscale[i];
}
}
__syncthreads();
if (ddbias != nullptr) {
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
ddy[index] += ddbias[i];
}
}
}
}
// math: dscale = inv_var * (dy - np.mean(dy, axis=(n,h,w) - (x-mean) *
// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) *
// ddx
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDScale(const T *x, const T *mean,
const T *variance, const T *ddx,
const T *dy, const int N, const int C,
const int sample_size,
const double epsilon, T *dscale) {
const int outer_size = C;
const int inner_size = N * sample_size;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage dy_storage;
__shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
__shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
__shared__ T dy_sum_val;
__shared__ T dy_mul_x_sub_mean_sum_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T dy_sum = 0;
T dy_mul_x_sub_mean_sum = 0;
T mean_val = mean[i];
T var_val = variance[i];
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
T dy_i = dy[index];
dy_sum += dy_i;
dy_mul_x_sub_mean_sum += (dy_i * (x[index] - mean_val));
}
dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
.Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
if (threadIdx.x == 0) {
dy_sum_val = dy_sum;
dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
}
__syncthreads();
if (ddx != nullptr) {
T dscale_tmp = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
dscale_tmp += ddx[index] * var_val *
(dy[index] - dy_sum_val / inner_size -
dy_mul_x_sub_mean_sum_val * (x[index] - mean_val) *
var_val * var_val / inner_size);
}
dscale_tmp =
BlockReduce(dscale_tmp_storage).Reduce(dscale_tmp, cub::Sum());
if (threadIdx.x == 0) {
dscale[i] += dscale_tmp;
}
__syncthreads();
}
}
}
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDScaleWithGlobal(
const T *ddx, const T *variance, const T *dy, const double epsilon,
const int N, const int C, const int sample_size, T *dscale) {
int outer_size = C;
int inner_size = N * sample_size;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ddx_mul_dy_storage;
__shared__ T ddx_mul_dy_sum_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T inv_var_i = 1.0 / sqrt(variance[i] + epsilon);
T ddx_mul_dy_sum = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index =
layout == framework::DataLayout::kNCHW
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
T ddx_i = ddx[index];
T dy_i = dy[index];
ddx_mul_dy_sum += (ddx_i * dy_i);
}
ddx_mul_dy_sum =
BlockReduce(ddx_mul_dy_storage).Reduce(ddx_mul_dy_sum, cub::Sum());
if (threadIdx.x == 0) {
ddx_mul_dy_sum_val = ddx_mul_dy_sum;
}
__syncthreads();
if (ddx != nullptr) {
dscale[i] = inv_var_i * ddx_mul_dy_sum_val;
}
}
}
// math: dx = ddscale * dy * inv_var
// math: ddy = scale * ddx * inv_var
template <typename T, framework::DataLayout layout>
__global__ void DoubleGradComputeDataWithGlobal(
const T *dy, const T *scale, const T *variance, const double epsilon,
const int C, const int sample_size, const int num, T *dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
if (scale != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
T inv_var = 1.0 / sqrt(variance[c] + epsilon);
dx[i] = dy[i] * scale[c] * inv_var;
}
}
}
template <typename DeviceContext, typename T>
void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
const DataLayout data_layout, const Tensor *X,
const Tensor *Scale, const Tensor *dY,
const Tensor *Saved_mean,
const Tensor *Saved_variance, const double epsilon,
const bool use_global_stats, const Tensor *ddX,
const Tensor *ddScale, const Tensor *ddBias,
Tensor *dX, Tensor *dScale, Tensor *ddY) {
const T *x_data = X->data<T>();
const T *dy_data = dY->data<T>();
const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());
const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data<T>());
const T *ddbias_data = (ddBias == nullptr ? nullptr : ddBias->data<T>());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
auto &x_dims = X->dims();
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int N = x_dims[0];
const int num = X->numel();
const int sample_size = num / N / C;
Tensor scale_tmp;
if (!Scale) {
scale_tmp.mutable_data<T>({C}, ctx.GetPlace());
set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
}
const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(C, max_blocks);
int grid1 = (num + block - 1) / block;
const T *mean_data, *variance_data;
if (use_global_stats) {
const auto *running_var = ctx.Input<Tensor>("Variance");
const auto *running_var_data = running_var->template data<T>();
variance_data = running_var_data;
} else {
const T *smean_data = Saved_mean->data<T>();
const T *svariance_data = Saved_variance->data<T>();
mean_data = smean_data;
variance_data = svariance_data;
}
if (dX) {
T *dx_data = dX->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, dX, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDataWithGlobal<
T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
dx_data);
} else {
DoubleGradComputeDataWithGlobal<
T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
dx_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDX<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
ddscale_data, N, C, sample_size, epsilon, dx_data);
} else {
DoubleGradComputeDX<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
ddscale_data, N, C, sample_size, epsilon, dx_data);
}
}
}
if (dScale) {
T *dscale_data = dScale->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, dScale, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDScaleWithGlobal<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
ddx_data, variance_data, dy_data, epsilon, N, C, sample_size,
dscale_data);
} else {
DoubleGradComputeDScaleWithGlobal<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
ddx_data, variance_data, dy_data, epsilon, N, C, sample_size,
dscale_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDScale<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, N, C,
sample_size, epsilon, dscale_data);
} else {
DoubleGradComputeDScale<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, N, C,
sample_size, epsilon, dscale_data);
}
}
}
if (ddY) {
T *ddy_data = ddY->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, ddY, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDataWithGlobal<
T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
ddx_data, scale_data, variance_data, epsilon, C, sample_size, num,
ddy_data);
} else {
DoubleGradComputeDataWithGlobal<
T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
ddx_data, scale_data, variance_data, epsilon, C, sample_size, num,
ddy_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDDY<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddscale_data, ddbias_data,
ddx_data, scale_data, N, C, sample_size, epsilon, ddy_data);
} else {
DoubleGradComputeDDY<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, mean_data, variance_data, ddscale_data, ddbias_data,
ddx_data, scale_data, N, C, sample_size, epsilon, ddy_data);
}
}
}
}
} // namespace operators
} // namespace paddle
......@@ -3167,7 +3167,7 @@ def instance_norm(input,
param_shape = [channel_num]
if param_attr and bias_attr:
if param_attr != False and bias_attr != False:
# create parameter
scale = helper.create_parameter(
attr=helper.param_attr,
......@@ -3190,7 +3190,7 @@ def instance_norm(input,
instance_norm_out = helper.create_variable_for_type_inference(dtype)
inputs = {"X": input}
if param_attr and bias_attr:
if param_attr != False and bias_attr != False:
inputs["Scale"] = scale
inputs["Bias"] = bias
......
......@@ -346,7 +346,7 @@ class TestRaiseNoDoubleGradOp(TestCase):
with fluid.dygraph.guard():
x = fluid.layers.ones(shape=[2, 3, 2, 2], dtype='float32')
x.stop_gradient = False
y = paddle.fluid.layers.batch_norm(x)
y = paddle.fluid.layers.group_norm(x, groups=1)
dx = fluid.dygraph.grad(
outputs=[y], inputs=[x], create_graph=True,
......
......@@ -68,5 +68,67 @@ class TestInstanceNormDoubleGradCheckWithoutParamBias(
[x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
class TestBatchNormDoubleGradCheck(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.data_layout = 'NCHW'
self.use_global_stats = False
self.shape = [2, 3, 4, 5]
@prog_scope()
def func(self, place):
prog = fluid.Program()
with fluid.program_guard(prog):
np.random.seed()
dtype = "float32"
eps = 0.005
atol = 1e-4
x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x')
z = fluid.layers.batch_norm(
input=x,
data_layout=self.data_layout,
use_global_stats=self.use_global_stats)
x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype)
gradient_checker.double_grad_check(
[x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestBatchNormDoubleGradCheckCase1(TestBatchNormDoubleGradCheck):
def init_test(self):
self.data_layout = 'NHWC'
self.use_global_stats = False
self.shape = [2, 3, 4, 5]
class TestBatchNormDoubleGradCheckCase2(TestBatchNormDoubleGradCheck):
def init_test(self):
self.data_layout = 'NCHW'
self.use_global_stats = True
self.shape = [2, 3, 4, 5]
class TestBatchNormDoubleGradCheckCase3(TestBatchNormDoubleGradCheck):
def init_test(self):
self.data_layout = 'NHWC'
self.use_global_stats = True
self.shape = [2, 3, 4, 5]
class TestBatchNormDoubleGradCheckCase4(TestBatchNormDoubleGradCheck):
def init_test(self):
self.data_layout = 'NCHW'
self.use_global_stats = False
self.shape = [2, 2, 3, 4, 5]
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册