Unverified commit 52be62c5 authored by ceci3, committed by GitHub

fix instance norm in dy (#24717)

* fix bn & in in dy, test=develop

* update instance_norm,test=develop

* fix bugs,test=develop

* add more case in unittest,test=develop

* fix,test=develop

* fix,test=develop
Parent 619848fa
@@ -24,8 +24,6 @@ namespace operators {
 void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const {
   OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNorm");
-  OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InstanceNorm");
-  OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "InstanceNorm");
   OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "InstanceNorm");
   OP_INOUT_CHECK(ctx->HasOutput("SavedMean"), "Output", "SavedMean",
                  "InstanceNorm");
@@ -51,37 +49,45 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const {
   auto C = x_dims[1];
   auto NxC = N * C;

-  auto scale_dim = ctx->GetInputDim("Scale");
-  auto bias_dim = ctx->GetInputDim("Bias");
-  PADDLE_ENFORCE_EQ(
-      scale_dim.size(), 1UL,
-      platform::errors::InvalidArgument(
-          "ShapeError: the dimension of scale must equal to 1."
-          "But received: the shape of scale is [%s], the dimension "
-          "of scale is [%d]",
-          scale_dim, scale_dim.size()));
-  PADDLE_ENFORCE_EQ(bias_dim.size(), 1UL,
-                    platform::errors::InvalidArgument(
-                        "ShapeError: the dimension of bias must equal to 1."
-                        "But received: the shape of bias is [%s],the dimension "
-                        "of bias is [%d]",
-                        bias_dim, bias_dim.size()));
-
-  bool check = !((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
-                                         framework::product(bias_dim) <= 0));
-
-  if (check) {
-    PADDLE_ENFORCE_EQ(scale_dim[0], C,
-                      platform::errors::InvalidArgument(
-                          "ShapeError: the shape of scale must equal to [%d]"
-                          "But received: the shape of scale is [%d]",
-                          C, scale_dim[0]));
-    PADDLE_ENFORCE_EQ(bias_dim[0], C,
-                      platform::errors::InvalidArgument(
-                          "ShapeError: the shape of bias must equal to [%d]"
-                          "But received: the shape of bias is [%d]",
-                          C, bias_dim[0]));
+  if (ctx->HasInput("Scale")) {
+    auto scale_dim = ctx->GetInputDim("Scale");
+    PADDLE_ENFORCE_EQ(
+        scale_dim.size(), 1UL,
+        platform::errors::InvalidArgument(
+            "ShapeError: the dimension of scale must equal to 1."
+            "But received: the shape of scale is [%s], the dimension "
+            "of scale is [%d]",
+            scale_dim, scale_dim.size()));
+
+    bool check = !((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0));
+
+    if (check) {
+      PADDLE_ENFORCE_EQ(scale_dim[0], C,
+                        platform::errors::InvalidArgument(
+                            "ShapeError: the shape of scale must equal to [%d]"
+                            "But received: the shape of scale is [%d]",
+                            C, scale_dim[0]));
+    }
+  }
+  if (ctx->HasInput("Bias")) {
+    auto bias_dim = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE_EQ(
+        bias_dim.size(), 1UL,
+        platform::errors::InvalidArgument(
+            "ShapeError: the dimension of bias must equal to 1."
+            "But received: the shape of bias is [%s],the dimension "
+            "of bias is [%d]",
+            bias_dim, bias_dim.size()));
+
+    bool check = !((!ctx->IsRuntime()) && (framework::product(bias_dim) <= 0));
+    if (check) {
+      PADDLE_ENFORCE_EQ(bias_dim[0], C,
+                        platform::errors::InvalidArgument(
+                            "ShapeError: the shape of bias must equal to [%d]"
+                            "But received: the shape of bias is [%d]",
+                            C, bias_dim[0]));
+    }
   }

   ctx->SetOutputDim("Y", x_dims);
@@ -100,12 +106,16 @@ framework::OpKernelType InstanceNormOp::GetExpectedKernelType(
   if (input_data_type == framework::proto::VarType::FP64) {
     in_param_type = framework::proto::VarType::FP64;
   }
-  PADDLE_ENFORCE_EQ(
-      in_param_type, ctx.Input<Tensor>("Scale")->type(),
-      platform::errors::InvalidArgument("Scale input should be of float type"));
-  PADDLE_ENFORCE_EQ(
-      in_param_type, ctx.Input<Tensor>("Bias")->type(),
-      platform::errors::InvalidArgument("Bias input should be of float type"));
+  if (ctx.HasInput("Scale")) {
+    PADDLE_ENFORCE_EQ(in_param_type, ctx.Input<Tensor>("Scale")->type(),
+                      platform::errors::InvalidArgument(
+                          "Scale input should be of float type"));
+  }
+  if (ctx.HasInput("Bias")) {
+    PADDLE_ENFORCE_EQ(in_param_type, ctx.Input<Tensor>("Bias")->type(),
+                      platform::errors::InvalidArgument(
+                          "Bias input should be of float type"));
+  }
   return framework::OpKernelType(input_data_type, ctx.GetPlace());
 }
@@ -121,10 +131,12 @@ void InstanceNormOpMaker::Make() {
   AddInput("X", "The input tensor");
   AddInput("Scale",
            "Scale is a 1-dimensional tensor of size C "
-           "that is applied to the output");
+           "that is applied to the output")
+      .AsDispensable();
   AddInput("Bias",
            "Bias is a 1-dimensional tensor of size C "
-           "that is applied to the output");
+           "that is applied to the output")
+      .AsDispensable();
   AddOutput("Y", "result after normalization");
   AddOutput("SavedMean",
             "Mean of the current mini batch, "
@@ -199,9 +211,26 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");

-    auto scale_e = framework::EigenVector<T>::Flatten(*scale);
+    Tensor scale_data;
+    Tensor bias_data;
+    if (!scale) {
+      scale_data.mutable_data<T>({C}, ctx.GetPlace());
+      set_constant(dev_ctx, &scale_data, static_cast<T>(1));
+    }
+    if (!bias) {
+      bias_data.mutable_data<T>({C}, ctx.GetPlace());
+      set_constant(dev_ctx, &bias_data, static_cast<T>(0));
+    }
+    auto scale_e = scale
+                       ? framework::EigenVector<T>::Flatten(*scale)
+                       : framework::EigenVector<T>::Flatten(
+                             const_cast<const framework::Tensor &>(scale_data));
     auto scale_arr = scale_e.reshape(C_shape);
-    auto bias_e = framework::EigenVector<T>::Flatten(*bias);
+    auto bias_e = bias ? framework::EigenVector<T>::Flatten(*bias)
+                       : framework::EigenVector<T>::Flatten(
+                             const_cast<const framework::Tensor &>(bias_data));
     auto bias_arr = bias_e.reshape(C_shape);

     y->mutable_data<T>(ctx.GetPlace());
@@ -219,7 +248,6 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
 void InstanceNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
   OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNormGrad");
-  OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InstanceNormGrad");
   OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
                  framework::GradVarName("Y"), "InstanceNormGrad");
   OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
@@ -230,15 +258,13 @@ void InstanceNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
   // check output
   OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                  framework::GradVarName("X"), "InstanceNormGrad");
-  if (ctx->HasOutput(framework::GradVarName("Scale"))) {
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Bias")), "Output",
-                   framework::GradVarName("Bias"), "InstanceNormGrad");
-  }

   const auto x_dims = ctx->GetInputDim("X");
   const int C = x_dims[1];
   ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   if (ctx->HasOutput(framework::GradVarName("Scale"))) {
     ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
+  }
+  if (ctx->HasOutput(framework::GradVarName("Bias"))) {
     ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
   }
 }
@@ -299,7 +325,18 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
     Eigen::DSizes<int, 2> param_shape(N, C);
     Eigen::DSizes<int, 2> shape(NxC, sample_size);

-    auto scale_e = framework::EigenVector<T>::Flatten(*scale);
+    math::SetConstant<platform::CPUDeviceContext, T> set_constant;
+
+    Tensor scale_data;
+    if (!scale) {
+      scale_data.mutable_data<T>({C}, ctx.GetPlace());
+      set_constant(dev_ctx, &scale_data, static_cast<T>(1));
+    }
+
+    auto scale_e = scale
+                       ? framework::EigenVector<T>::Flatten(*scale)
+                       : framework::EigenVector<T>::Flatten(
+                             const_cast<const framework::Tensor &>(scale_data));
     auto mean_e = framework::EigenVector<T>::Flatten(*saved_mean);
     auto inv_var_e = framework::EigenVector<T>::Flatten(*saved_inv_variance);
     auto dy_e = framework::EigenVector<T>::Flatten(*d_y);
@@ -314,7 +351,6 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
     auto tmp = (x_arr - mean_arr.eval().broadcast(bcast)) *
                inv_var_arr.eval().broadcast(bcast);

-    math::SetConstant<platform::CPUDeviceContext, T> set_constant;
     // math: d_bias = np.sum(d_y, axis=(n,h,w))
     // math: d_scale = np.sum((X-mean) / inv_std * dy, axis=(n, h,w))
     if (d_scale && d_bias) {
@@ -324,8 +360,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
       set_constant(dev_ctx, d_bias, static_cast<T>(0));

       auto d_scale_e = framework::EigenVector<T>::Flatten(*d_scale);
-      auto d_bias_e = framework::EigenVector<T>::Flatten(*d_bias);
       auto d_scale_data = d_scale_e.reshape(C_shape);
+      auto d_bias_e = framework::EigenVector<T>::Flatten(*d_bias);
       auto d_bias_data = d_bias_e.reshape(C_shape);
       d_bias_data.device(*place) =
           dy_arr.sum(mean_rdims).reshape(param_shape).sum(rdims);
@@ -360,8 +396,6 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
 void InstanceNormDoubleGradOp::InferShape(
     framework::InferShapeContext *ctx) const {
   OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNormDoubleGrad");
-  OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
-                 "InstanceNormDoubleGrad");
   OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
                  "InstanceNormDoubleGrad");
   OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
@@ -426,6 +460,9 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
     auto *dScale = ctx.Output<Tensor>("DScale");
     auto *ddY = ctx.Output<Tensor>("DDY");

+    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    math::SetConstant<platform::CPUDeviceContext, T> set_constant;
+
     const auto &x_dims = X->dims();
     int N, C, H, W, D;
     ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
@@ -455,7 +492,13 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
     mean_tile_data = mean_arr.transpose().replicate(sample_size, 1);
     inv_var_tile_data = inv_var_arr.transpose().replicate(sample_size, 1);

-    ConstEigenVectorArrayMap<T> scale_arr(Scale->data<T>(), C);
+    Tensor Scale_data;
+    if (!Scale) {
+      Scale_data.mutable_data<T>({C}, ctx.GetPlace());
+      set_constant(dev_ctx, &Scale_data, static_cast<T>(1));
+    }
+    ConstEigenVectorArrayMap<T> scale_arr(
+        Scale ? Scale->data<T>() : Scale_data.data<T>(), C);

     Tensor scale_tile;
     scale_tile.Resize({sample_size, NxC});
@@ -483,9 +526,6 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
     // inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
     // axis=(h,w))))

-    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    math::SetConstant<platform::CPUDeviceContext, T> set_constant;
-
     Tensor x_sub_mean_mul_invstd;
     x_sub_mean_mul_invstd.Resize({sample_size, NxC});
     x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace());
......
@@ -146,10 +146,19 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
     const int max_blocks = std::max(max_threads / block, 1);
     const int grid = std::min((NxC + block - 1) / block, max_blocks);

-    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
-        scale->data<T>(), scale_tmp.data<T>(), N, C);
-    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
-        bias->data<T>(), bias_tmp.data<T>(), N, C);
+    math::SetConstant<platform::CUDADeviceContext, T> set_constant;
+    if (scale) {
+      repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
+          scale->data<T>(), scale_tmp.data<T>(), N, C);
+    } else {
+      set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
+    }
+    if (bias) {
+      repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
+          bias->data<T>(), bias_tmp.data<T>(), N, C);
+    } else {
+      set_constant(dev_ctx, &bias_tmp, static_cast<T>(0));
+    }

     auto handle = dev_ctx.cudnn_handle();
@@ -267,24 +276,27 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
       d_scale->mutable_data<T>(ctx.GetPlace());
       d_bias->mutable_data<T>(ctx.GetPlace());
     }
-    PADDLE_ENFORCE_EQ(
-        scale->dims().size(), 1UL,
-        platform::errors::InvalidArgument(
-            "The `shape` in InstanceNormOp is invalid: "
-            "the size of scale's dimensions must be equal to 1. But "
-            "received: the size of scale's dimensions"
-            "is [%d]",
-            scale->dims().size()));
-    PADDLE_ENFORCE_EQ(scale->dims()[0], C,
-                      platform::errors::InvalidArgument(
-                          "The `shape` in InstanceNormOp is invalid: "
-                          "the first dimension of scale must be equal to "
-                          "Channels([%d]). But received: "
-                          "the first dimension of scale is [%d],"
-                          "the dimensions of scale is [%s], ",
-                          C, scale->dims()[0], scale->dims()));
+    if (scale) {
+      PADDLE_ENFORCE_EQ(
+          scale->dims().size(), 1UL,
+          platform::errors::InvalidArgument(
+              "The `shape` in InstanceNormOp is invalid: "
+              "the size of scale's dimensions must be equal to 1. But "
+              "received: the size of scale's dimensions"
+              "is [%d]",
+              scale->dims().size()));
+      PADDLE_ENFORCE_EQ(scale->dims()[0], C,
+                        platform::errors::InvalidArgument(
+                            "The `shape` in InstanceNormOp is invalid: "
+                            "the first dimension of scale must be equal to "
+                            "Channels([%d]). But received: "
+                            "the first dimension of scale is [%d],"
+                            "the dimensions of scale is [%s], ",
+                            C, scale->dims()[0], scale->dims()));
+    }

     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    math::SetConstant<platform::CUDADeviceContext, T> set_constant;

     const int n = x->numel();
     const int block = 512;
@@ -300,8 +312,12 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
         ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
     Tensor d_bias_tmp =
         ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
-    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
-        scale->data<T>(), scale_tmp.data<T>(), N, C);
+    if (scale) {
+      repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
+          scale->data<T>(), scale_tmp.data<T>(), N, C);
+    } else {
+      set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
+    }

     std::vector<int> dims;
     std::vector<int> strides;
@@ -361,7 +377,7 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
     } else {
       if (d_x) {
         GradComputeDX<T, block><<<NxC, block, 0, dev_ctx.stream()>>>(
-            d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
+            d_y->data<T>(), scale_tmp.data<BatchNormParamType<T>>(),
             saved_mean_data, x->data<T>(), saved_var_data, C, H * W * D,
             d_x->data<T>());
       }
@@ -610,7 +626,6 @@ class InstanceNormDoubleGradKernel<platform::CUDADeviceContext, T>
     auto *ddY = ctx.Output<Tensor>("DDY");

     const T *x_data = X->data<T>();
-    const T *scale_data = Scale->data<T>();
     const T *dy_data = dY->data<T>();
     const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());
@@ -620,6 +635,9 @@ class InstanceNormDoubleGradKernel<platform::CUDADeviceContext, T>
     const T *mean_data = Saved_mean->data<T>();
     const T *variance_data = Saved_variance->data<T>();

+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
+
     auto &x_dims = X->dims();
     int N, C, H, W, D;
     ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
@@ -627,15 +645,19 @@ class InstanceNormDoubleGradKernel<platform::CUDADeviceContext, T>
     const int n = X->numel();
     int sample_size = n / N / C;

-    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    Tensor scale_tmp;
+    if (!Scale) {
+      scale_tmp.mutable_data<T>({C}, ctx.GetPlace());
+      set_zero(dev_ctx, &scale_tmp, static_cast<T>(1));
+    }
+    const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();

     const int block = 512;
     int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
     const int max_blocks = std::max(max_threads / block, 1);
     const int grid = NxC;
     const int grid1 = (C + block - 1) / block;

-    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
     if (dX) {
       T *dx_data = dX->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, dX, static_cast<T>(0));
......
@@ -34,6 +34,7 @@
 // need to manually specify them in this map.
 std::map<std::string, std::set<std::string>> op_ins_map = {
     {"layer_norm", {"X", "Scale", "Bias"}},
+    {"instance_norm", {"X", "Scale", "Bias"}},
     {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
     {"label_smooth", {"X", "PriorDist"}},
     {"assign", {"X"}},
......
@@ -1028,16 +1028,16 @@ class InstanceNorm(layers.Layer):
         num_channels(int): Indicate the number of channels of the input ``Tensor``.
         epsilon(float, optional): A value added to the denominator for
             numerical stability. Default is 1e-5.
-        param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
+        param_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
             of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm
             will create ParamAttr as param_attr, the name of scale can be set in ParamAttr.
             If the Initializer of the param_attr is not set, the parameter is initialized
-            one. Default: None.
-        bias_attr(ParamAttr, optional): The parameter attribute for the bias of instance_norm.
+            one. If it is set to False, will not create param_attr. Default: None.
+        bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of instance_norm.
             If it is set to None or one attribute of ParamAttr, instance_norm
             will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
             If the Initializer of the bias_attr is not set, the bias is initialized zero.
-            Default: None.
+            If it is set to False, will not create bias_attr. Default: None.
         dtype(str, optional): Indicate the data type of the input ``Tensor``,
             which can be float32 or float64. Default: float32.
@@ -1071,25 +1071,30 @@ class InstanceNorm(layers.Layer):
                  bias_attr=None,
                  dtype='float32'):
         super(InstanceNorm, self).__init__()
-        assert bias_attr is not False, "bias_attr should not be False in InstanceNorm."
+        if param_attr == False or bias_attr == False:
+            assert bias_attr == param_attr, "param_attr and bias_attr must be set to Fasle at the same time in InstanceNorm"

         self._epsilon = epsilon
         self._param_attr = param_attr
         self._bias_attr = bias_attr
         self._dtype = dtype

-        self.scale = self.create_parameter(
-            attr=self._param_attr,
-            shape=[num_channels],
-            dtype=self._dtype,
-            default_initializer=Constant(1.0),
-            is_bias=False)
-        self.bias = self.create_parameter(
-            attr=self._bias_attr,
-            shape=[num_channels],
-            dtype=self._dtype,
-            default_initializer=Constant(0.0),
-            is_bias=True)
+        if param_attr != False and bias_attr != False:
+            self.scale = self.create_parameter(
+                attr=self._param_attr,
+                shape=[num_channels],
+                dtype=self._dtype,
+                default_initializer=Constant(1.0),
+                is_bias=False)
+            self.bias = self.create_parameter(
+                attr=self._bias_attr,
+                shape=[num_channels],
+                dtype=self._dtype,
+                default_initializer=Constant(0.0),
+                is_bias=True)
+        else:
+            self.scale = None
+            self.bias = None

     def forward(self, input):
         if in_dygraph_mode():
@@ -1102,7 +1107,10 @@ class InstanceNorm(layers.Layer):
         attrs = {"epsilon": self._epsilon}

-        inputs = {"X": [input], "Scale": [self.scale], "Bias": [self.bias]}
+        if self.scale and self.bias:
+            inputs = {"X": [input], "Scale": [self.scale], "Bias": [self.bias]}
+        else:
+            inputs = {"X": [input]}

         saved_mean = self._helper.create_variable_for_type_inference(
             dtype=self._dtype, stop_gradient=True)
......
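In plain terms, the dygraph change above lets `InstanceNorm` run without learnable parameters: when both `param_attr` and `bias_attr` are `False`, no `scale`/`bias` parameters are created and the op falls back to scale = 1, bias = 0. A minimal usage sketch based on the new unit tests later in this diff (the input shape and channel count here are illustrative assumptions, not part of the commit):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable

x = np.random.random((2, 3, 5, 5)).astype(np.float32)
with fluid.dygraph.guard():
    # Both attributes set to False: no scale/bias parameters are created,
    # and normalization uses the implicit scale = 1, bias = 0.
    norm = fluid.dygraph.InstanceNorm(3, param_attr=False, bias_attr=False)
    y = norm(to_variable(x))
    print(y.numpy().shape)  # (2, 3, 5, 5)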
@@ -3114,15 +3114,17 @@ def instance_norm(input,
             The data type is float32 or float64.
         epsilon(float, Default 1e-05): A value added to the denominator for
             numerical stability. Default is 1e-5.
-        param_attr(ParamAttr|None): The parameter attribute for Parameter `scale`
+        param_attr(ParamAttr|None|bool, optional): The parameter attribute for Parameter `scale`
             of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm
             will create ParamAttr as param_attr, the name of scale can be set in ParamAttr.
             If the Initializer of the param_attr is not set, the parameter is initialized
-            with Xavier. Default: None.
-        bias_attr(ParamAttr|None): The parameter attribute for the bias of instance_norm.
+            with Xavier. If the param_attr is set to False, instance_norm will not create param_attr.
+            Default: None.
+        bias_attr(ParamAttr|None|bool, optional): The parameter attribute for the bias of instance_norm.
             If it is set to None or one attribute of ParamAttr, instance_norm
             will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
             If the Initializer of the bias_attr is not set, the bias is initialized zero.
+            If the bias_attr is set to False, instance_norm will not create bias_attr.
             Default: None.
         name(string, Default None): A name for this layer(optional). If set None, the layer
             will be named automatically.
@@ -3142,7 +3144,9 @@ def instance_norm(input,
     """
     check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                              'instance_norm')
-    assert bias_attr is not False, "bias_attr should not be False in instance_norm."
+    if param_attr is False:
+        assert bias_attr is False, "param_attr and bias_attr must be set to Fasle at the same time in instance_norm"

     helper = LayerHelper('instance_norm', **locals())
     dtype = helper.input_dtype()
@@ -3155,18 +3159,19 @@ def instance_norm(input,
     param_shape = [channel_num]

-    # create parameter
-    scale = helper.create_parameter(
-        attr=helper.param_attr,
-        shape=param_shape,
-        dtype=dtype,
-        default_initializer=Constant(1.0))
-    bias = helper.create_parameter(
-        attr=helper.bias_attr,
-        shape=param_shape,
-        dtype=dtype,
-        is_bias=True,
-        default_initializer=Constant(0.0))
+    if param_attr and bias_attr:
+        # create parameter
+        scale = helper.create_parameter(
+            attr=helper.param_attr,
+            shape=param_shape,
+            dtype=dtype,
+            default_initializer=Constant(1.0))
+        bias = helper.create_parameter(
+            attr=helper.bias_attr,
+            shape=param_shape,
+            dtype=dtype,
+            is_bias=True,
+            default_initializer=Constant(0.0))

     # create output
     saved_mean = helper.create_variable_for_type_inference(
@@ -3176,13 +3181,14 @@ def instance_norm(input,
     instance_norm_out = helper.create_variable_for_type_inference(dtype)

+    inputs = {"X": input}
+    if param_attr and bias_attr:
+        inputs["Scale"] = scale
+        inputs["Bias"] = bias
+
     helper.append_op(
         type="instance_norm",
-        inputs={
-            "X": input,
-            "Scale": scale,
-            "Bias": bias,
-        },
+        inputs=inputs,
         outputs={
             "Y": instance_norm_out,
             "SavedMean": saved_mean,
......
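For the static-graph API, a minimal sketch of the same `False`/`False` usage, mirroring the new double-grad test at the end of this diff (the feed shape, program names, and executor setup are illustrative assumptions):

import numpy as np
import paddle.fluid as fluid

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[-1, 3, 5, 5], dtype='float32')
    # param_attr and bias_attr must be False together; in that case no
    # Scale/Bias inputs are attached to the instance_norm op.
    y = fluid.layers.instance_norm(input=x, param_attr=False, bias_attr=False)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
out, = exe.run(main_prog,
               feed={'x': np.random.random((2, 3, 5, 5)).astype('float32')},
               fetch_list=[y])
print(out.shape)  # (2, 3, 5, 5)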
@@ -20,6 +20,7 @@ import paddle.fluid as fluid
 from paddle.fluid.op import Operator
 from op_test import OpTest
 from paddle.fluid import Program, program_guard
+from paddle.fluid.dygraph import to_variable

 def _reference_instance_norm_naive(x, scale, bias, epsilon, mean, var):
@@ -214,5 +215,63 @@ class TestInstanceNormOpError(unittest.TestCase):
         self.assertRaises(TypeError, fluid.layers.instance_norm, x2)

+class TestElasticNormOp(unittest.TestCase):
+    def init_test_case(self):
+        self.epsilon = 1e-5
+        self.places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu(
+                "instance_norm"):
+            self.places.append(core.CUDAPlace(0))
+
+    def test_norm(self):
+        self.init_test_case()
+        inputs = np.random.random((2, 3, 5, 5)).astype(np.float32)
+        shape = inputs.shape
+        n, c, h, w = shape[0], shape[1], shape[2], shape[3]
+        scale_shape = [c]
+        mean_shape = [n * c]
+        scale = np.ones(scale_shape).astype(np.float32)
+        bias = np.zeros(scale_shape).astype(np.float32)
+        mean, variance = _cal_mean_variance(inputs, self.epsilon, mean_shape)
+        out_np, _, _ = _reference_instance_norm_naive(
+            inputs, scale, bias, self.epsilon, mean, variance)
+
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                instance_norm = fluid.dygraph.InstanceNorm(
+                    5, param_attr=False, bias_attr=False)
+                outputs = instance_norm(to_variable(inputs))
+                self.assertTrue(np.allclose(outputs.numpy(), out_np, atol=1e-6))
+
+
+class TestElasticNormOpCase2(unittest.TestCase):
+    def init_test_case(self):
+        self.epsilon = 1e-5
+        self.places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu(
+                "instance_norm"):
+            self.places.append(core.CUDAPlace(0))
+
+    def test_norm(self):
+        self.init_test_case()
+        inputs = np.random.random((2, 3, 5, 5)).astype(np.float32)
+        shape = inputs.shape
+        n, c, h, w = shape[0], shape[1], shape[2], shape[3]
+        scale_shape = [c]
+        mean_shape = [n * c]
+        scale = np.ones(scale_shape).astype(np.float32)
+        bias = np.zeros(scale_shape).astype(np.float32)
+        mean, variance = _cal_mean_variance(inputs, self.epsilon, mean_shape)
+        out_np, _, _ = _reference_instance_norm_naive(
+            inputs, scale, bias, self.epsilon, mean, variance)
+
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                instance_norm = fluid.dygraph.InstanceNorm(
+                    3, param_attr=True, bias_attr=True)
+                outputs = instance_norm(to_variable(inputs))
+                self.assertTrue(np.allclose(outputs.numpy(), out_np, atol=1e-6))
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -49,5 +49,24 @@ class TestInstanceNormDoubleGradCheck(unittest.TestCase):
             self.func(p)

+class TestInstanceNormDoubleGradCheckWithoutParamBias(
+        TestInstanceNormDoubleGradCheck):
+    @prog_scope()
+    def func(self, place):
+        prog = fluid.Program()
+        with fluid.program_guard(prog):
+            np.random.seed()
+            shape = [2, 3, 4, 5]
+            dtype = "float32"
+            eps = 0.005
+            atol = 1e-4
+            x = layers.create_parameter(dtype=dtype, shape=shape, name='x')
+            z = fluid.layers.instance_norm(
+                input=x, param_attr=False, bias_attr=False)
+            x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+            gradient_checker.double_grad_check(
+                [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
+
+
 if __name__ == "__main__":
     unittest.main()