From 99c6497b2056c09d7f0fe520f68c369043d61586 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 8 Jun 2022 11:36:29 +0800 Subject: [PATCH] [Phi]Move group op kernel into PHI and add yaml / unittest (#43104) * move_group_norm * move group norm backward * fix code format * modify code according comment --- paddle/fluid/operators/group_norm_op.cc | 106 +--- paddle/fluid/operators/group_norm_op_npu.cc | 3 +- paddle/phi/infermeta/ternary.cc | 117 +++++ paddle/phi/infermeta/ternary.h | 10 + .../phi/kernels/cpu/group_norm_grad_kernel.cc | 204 ++++++++ paddle/phi/kernels/cpu/group_norm_kernel.cc | 210 ++++++++ .../phi/kernels/gpu/group_norm_grad_kernel.cu | 452 ++++++++++++++++++ paddle/phi/kernels/gpu/group_norm_kernel.cu | 233 +++++++++ paddle/phi/kernels/gpu/group_norm_utils.h | 174 +++++++ paddle/phi/kernels/group_norm_grad_kernel.h | 39 ++ paddle/phi/kernels/group_norm_kernel.h | 35 ++ paddle/phi/ops/compat/group_norm_sig.cc | 39 ++ python/paddle/fluid/dygraph/nn.py | 12 +- .../tests/unittests/test_group_norm_op.py | 27 +- .../tests/unittests/test_group_norm_op_v2.py | 9 + python/paddle/nn/functional/norm.py | 2 +- python/paddle/utils/code_gen/api.yaml | 12 + python/paddle/utils/code_gen/backward.yaml | 13 + tools/infrt/skipped_phi_api.json | 2 +- 19 files changed, 1597 insertions(+), 102 deletions(-) create mode 100644 paddle/phi/kernels/cpu/group_norm_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/group_norm_kernel.cc create mode 100644 paddle/phi/kernels/gpu/group_norm_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/group_norm_kernel.cu create mode 100644 paddle/phi/kernels/gpu/group_norm_utils.h create mode 100644 paddle/phi/kernels/group_norm_grad_kernel.h create mode 100644 paddle/phi/kernels/group_norm_kernel.h create mode 100644 paddle/phi/ops/compat/group_norm_sig.cc diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index 4d989ed1f2e..e35598f23e9 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -12,13 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/group_norm_op.h" - #include #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/ternary.h" + namespace paddle { namespace operators { @@ -29,91 +33,6 @@ using DataLayout = framework::DataLayout; class GroupNormOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GroupNorm"); - OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "GroupNorm"); - OP_INOUT_CHECK(ctx->HasOutput("Mean"), "Output", "Mean", "GroupNorm"); - OP_INOUT_CHECK(ctx->HasOutput("Variance"), "Output", "Variance", - "GroupNorm"); - - auto x_dim = ctx->GetInputDim("X"); - PADDLE_ENFORCE_GE( - x_dim.size(), 2, - platform::errors::InvalidArgument( - "The Input(X)'s dimension of Op(group_norm) must be " - "greater than 1. 
But received: %u-D Tensor, which shape is [%s].", - x_dim.size(), x_dim)); - - const std::string data_layout_str = - ctx->Attrs().Get("data_layout"); - const framework::DataLayout data_layout = - framework::StringToDataLayout(data_layout_str); - const int64_t channel_num = - (data_layout == DataLayout::kNCHW ? x_dim[1] : x_dim[x_dim.size() - 1]); - auto batch_size = x_dim[0]; - auto groups = ctx->Attrs().Get("groups"); - PADDLE_ENFORCE_LE( - groups, channel_num, - platform::errors::InvalidArgument( - "The Attr(groups) of Op(group_norm) must be less than or " - "equal to the number of channels. But received: groups " - "is [%s], channels is [%s], the Attr(data_layout) " - "is [%s]. The error may come from wrong data_layout setting.", - groups, channel_num, data_layout_str)); - PADDLE_ENFORCE_GE( - groups, 1, - platform::errors::InvalidArgument( - "The Attr(groups) of Op(group_norm) must be " - "greater than or equal to 1. But received: groups is [%s].", - groups)); - PADDLE_ENFORCE_EQ( - channel_num % groups, 0, - platform::errors::InvalidArgument( - "Expected number of channels in input to be divisible by " - "num_groups, but got input channel is %d and num_groups is %d", - channel_num, groups)); - - if (ctx->HasInput("Scale")) { - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Scale").size(), 1UL, - platform::errors::InvalidArgument( - "The Input(Scale) of Op(group_norm) should be 1-D Tensor. " - "But received: %u-D Tensor, the shape of Input(Scale) is [%s].", - ctx->GetInputDim("Scale").size(), ctx->GetInputDim("Scale"))); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Scale")[0], channel_num, - platform::errors::InvalidArgument( - "The Input(Scale)'s first dimension size of Op(group_norm) must " - "be equal to the number of channels. But received: the " - "Input(Scale)'s first dimension size is [%s], the channels is " - "[%s], the Attr(data_layout) is [%s]. The error may come " - "from wrong data_layout setting.", - ctx->GetInputDim("Scale")[0], channel_num, data_layout_str)); - } - if (ctx->HasInput("Bias")) { - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Bias").size(), 1UL, - platform::errors::InvalidArgument( - "The Input(Bias) of Op(group_norm) should be 1-D Tensor. " - "But received: %u-D Tensor, the shape of Input(Bias) is [%s].", - ctx->GetInputDim("Bias").size(), ctx->GetInputDim("Bias"))); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Bias")[0], channel_num, - platform::errors::InvalidArgument( - "The Input(Bias)'s first dimension size of " - "Op(group_norm) must be equal to the number of channels. " - "But received: the Input(Bias)'s first dimension size is [%s], " - "the channels is [%s], the Attr(data_layout) is [%s]. 
The " - "error may come from wrong data_layout setting.", - ctx->GetInputDim("Bias")[0], channel_num, data_layout_str)); - } - - ctx->SetOutputDim("Y", ctx->GetInputDim("X")); - ctx->SetOutputDim("Mean", {batch_size, groups}); - ctx->SetOutputDim("Variance", {batch_size, groups}); - ctx->ShareLoD("X", "Y"); - } }; class GroupNormOpMaker : public framework::OpProtoAndCheckerMaker { @@ -252,17 +171,14 @@ class GroupNormOpInferVarType } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(group_norm, GroupNormInferShapeFunctor, + PD_INFER_META(phi::GroupNormInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker, ops::GroupNormOpInferVarType, ops::GroupNormGradMaker, - ops::GroupNormGradMaker); + ops::GroupNormGradMaker, + GroupNormInferShapeFunctor); REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp, ops::GroupNormGradInplaceInferer); -REGISTER_OP_CPU_KERNEL( - group_norm, ops::GroupNormKernel, - ops::GroupNormKernel); -REGISTER_OP_CPU_KERNEL( - group_norm_grad, - ops::GroupNormGradKernel, - ops::GroupNormGradKernel); diff --git a/paddle/fluid/operators/group_norm_op_npu.cc b/paddle/fluid/operators/group_norm_op_npu.cc index dfc509941bc..8217815f9d7 100644 --- a/paddle/fluid/operators/group_norm_op_npu.cc +++ b/paddle/fluid/operators/group_norm_op_npu.cc @@ -14,7 +14,8 @@ limitations under the License. */ #include -#include "paddle/fluid/operators/group_norm_op.h" +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index d84cc9e6d75..a22f720b97e 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/ternary.h" +#include "paddle/phi/common/layout.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/kernels/funcs/common_shape.h" @@ -363,6 +364,122 @@ void GraphSendRecvInferMeta(const MetaTensor& x, } } +void GroupNormInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& bias, + float epsilon, + int groups, + const std::string& data_layout_str, + MetaTensor* y, + MetaTensor* mean, + MetaTensor* variance) { + PADDLE_ENFORCE_NE(y, + nullptr, + phi::errors::InvalidArgument( + "The y in GroupNormInferMeta can't be nullptr.")); + PADDLE_ENFORCE_NE(mean, + nullptr, + phi::errors::InvalidArgument( + "The mean in GroupNormInferMeta can't be nullptr.")); + PADDLE_ENFORCE_NE( + variance, + nullptr, + phi::errors::InvalidArgument( + "The variance in GroupNormInferMeta can't be nullptr.")); + + auto x_dim = x.dims(); + PADDLE_ENFORCE_GE( + x_dim.size(), + 2, + phi::errors::InvalidArgument( + "The Input(X)'s dimension of Op(group_norm) must be " + "greater than 1. But received: %u-D Tensor, which shape is [%s].", + x_dim.size(), + x_dim)); + + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_layout_str); + const int64_t channel_num = + (data_layout == DataLayout::kNCHW ? x_dim[1] : x_dim[x_dim.size() - 1]); + auto batch_size = x_dim[0]; + PADDLE_ENFORCE_LE( + groups, + channel_num, + phi::errors::InvalidArgument( + "The Attr(groups) of Op(group_norm) must be less than or " + "equal to the number of channels. But received: groups " + "is [%s], channels is [%s], the Attr(data_layout) " + "is [%s]. 
The error may come from wrong data_layout setting.", + groups, + channel_num, + data_layout_str)); + PADDLE_ENFORCE_GE( + groups, + 1, + phi::errors::InvalidArgument( + "The Attr(groups) of Op(group_norm) must be " + "greater than or equal to 1. But received: groups is [%s].", + groups)); + PADDLE_ENFORCE_EQ( + channel_num % groups, + 0, + phi::errors::InvalidArgument( + "Expected number of channels in input to be divisible by " + "num_groups, but got input channel is %d and num_groups is %d", + channel_num, + groups)); + + if (scale) { + PADDLE_ENFORCE_EQ( + scale.dims().size(), + 1UL, + phi::errors::InvalidArgument( + "The Input(Scale) of Op(group_norm) should be 1-D Tensor. " + "But received: %u-D Tensor, the shape of Input(Scale) is [%s].", + scale.dims().size(), + scale.dims())); + PADDLE_ENFORCE_EQ( + scale.dims()[0], + channel_num, + phi::errors::InvalidArgument( + "The Input(Scale)'s first dimension size of Op(group_norm) must " + "be equal to the number of channels. But received: the " + "Input(Scale)'s first dimension size is [%s], the channels is " + "[%s], the Attr(data_layout) is [%s]. The error may come " + "from wrong data_layout setting.", + scale.dims()[0], + channel_num, + data_layout_str)); + } + if (bias) { + PADDLE_ENFORCE_EQ( + bias.dims().size(), + 1UL, + phi::errors::InvalidArgument( + "The Input(Bias) of Op(group_norm) should be 1-D Tensor. " + "But received: %u-D Tensor, the shape of Input(Bias) is [%s].", + bias.dims().size(), + bias.dims())); + PADDLE_ENFORCE_EQ( + bias.dims()[0], + channel_num, + phi::errors::InvalidArgument( + "The Input(Bias)'s first dimension size of " + "Op(group_norm) must be equal to the number of channels. " + "But received: the Input(Bias)'s first dimension size is [%s], " + "the channels is [%s], the Attr(data_layout) is [%s]. The " + "error may come from wrong data_layout setting.", + bias.dims()[0], + channel_num, + data_layout_str)); + } + y->set_dims(x_dim); + y->set_dtype(x.dtype()); + y->share_lod(x); + mean->set_dims({batch_size, groups}); + variance->set_dims({batch_size, groups}); +} + void LayerNormInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 760011ad829..40461d299fb 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -69,6 +69,16 @@ void GraphSendRecvInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* dst_count); +void GroupNormInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& bias, + float epsilon, + int groups, + const std::string& data_layout, + MetaTensor* y, + MetaTensor* mean, + MetaTensor* variance); + void LayerNormInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/paddle/phi/kernels/cpu/group_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/group_norm_grad_kernel.cc new file mode 100644 index 00000000000..949f9148761 --- /dev/null +++ b/paddle/phi/kernels/cpu/group_norm_grad_kernel.cc @@ -0,0 +1,204 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/group_norm_grad_kernel.h" + +#include +#include +#include +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/extensions.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void GroupNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + const DenseTensor& y, + const DenseTensor& mean, + const DenseTensor& var, + const DenseTensor& d_y, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* d_x, + DenseTensor* d_scale, + DenseTensor* d_bias) { + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_layout_str); + const auto scale_ptr = scale.get_ptr(); + const auto bias_ptr = bias.get_ptr(); + const auto& x_dims = y.dims(); + const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int group_size = C / groups; + + dev_ctx.template Alloc(d_x); + phi::funcs::SetConstant set_zero; + + auto* x_data = y.data(); + auto* d_x_data = d_x->data(); + auto* y_data = d_y.data(); + auto* var_data = var.data(); + T* d_scale_data = nullptr; + if (d_scale) { + dev_ctx.template Alloc(d_scale); + set_zero(dev_ctx, d_scale, static_cast(0)); + d_scale_data = d_scale->data(); + } + T* d_bias_data = nullptr; + if (d_bias) { + dev_ctx.template Alloc(d_bias); + set_zero(dev_ctx, d_bias, static_cast(0)); + d_bias_data = d_bias->data(); + } + + const T* scale_data = nullptr; + if (scale_ptr) scale_data = scale_ptr->data(); + const T* bias_data = nullptr; + if (bias_ptr) bias_data = bias_ptr->data(); + + int imsize = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < x_dims.size(); ++i) { + imsize *= x_dims[i]; + } + } else { + for (int i = 1; i < x_dims.size() - 1; ++i) { + imsize *= x_dims[i]; + } + } + auto* iter_x_data = x_data; + auto* iter_d_x_data = d_x_data; + auto* iter_y_data = y_data; + for (int bid = 0; bid < x_dims[0]; bid++) { + for (int gid = 0; gid < groups; gid++) { + T x_var = var_data[bid * groups + gid]; + T var_inv = 1.0 / sqrt(x_var + epsilon); + int number = std::min(group_size, static_cast(C - gid * group_size)); + T number_inv = 1.0 / (number * imsize); + auto* tmp_x = iter_x_data; + auto* tmp_y = iter_y_data; + auto* tmp_d_x = iter_d_x_data; + auto* x_src_data = iter_x_data; + auto* y_src_data = iter_y_data; + auto* iter_x_data_backup = iter_x_data; + auto* iter_y_data_backup = iter_y_data; + auto* iter_d_x_data_backup = iter_d_x_data; + T dp_scale = 0, dp_bias = 0; + + if (data_layout == DataLayout::kNCHW) { + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; + imid++, iter_x_data++, iter_y_data++) { + T val = iter_x_data[0]; + if (bias_data) val -= bias_data[gid * group_size + cid]; + T dval = iter_y_data[0]; + dp_scale += val * dval; + if (scale_data) + dp_bias += 
dval * scale_data[gid * group_size + cid]; + + if (scale_data && scale_data[gid * group_size + cid] != 0) + val /= scale_data[gid * group_size + cid]; + if (d_bias_data) d_bias_data[gid * group_size + cid] += dval; + if (d_scale_data) + d_scale_data[gid * group_size + cid] += val * dval; + } + } + + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; + imid++, iter_d_x_data++, tmp_x++, tmp_y++) { + T v_y = tmp_x[0]; + T dly = tmp_y[0]; + T dss = dp_scale; + T dbs = dp_bias; + T v_scale = 1., v_bias = 0.; + if (scale_data) v_scale = scale_data[gid * group_size + cid]; + if (bias_data) v_bias = bias_data[gid * group_size + cid]; + v_y -= v_bias; + if (v_scale != 0) v_y /= v_scale; + iter_d_x_data[0] = + (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) * + var_inv; + } + } + } else { + for (int cid = 0; cid < number; cid++) { + iter_x_data = x_src_data + cid; + iter_y_data = y_src_data + cid; + for (int imid = 0; imid < imsize; + imid++, iter_x_data += C, iter_y_data += C) { + T val = iter_x_data[0]; + if (bias_data) val -= bias_data[gid * group_size + cid]; + T dval = iter_y_data[0]; + dp_scale += val * dval; + if (scale_data) + dp_bias += dval * scale_data[gid * group_size + cid]; + + if (scale_data && scale_data[gid * group_size + cid] != 0) + val /= scale_data[gid * group_size + cid]; + if (d_bias_data) d_bias_data[gid * group_size + cid] += dval; + if (d_scale_data) + d_scale_data[gid * group_size + cid] += val * dval; + } + } + + for (int cid = 0; cid < number; cid++) { + tmp_x = x_src_data + cid; + tmp_y = y_src_data + cid; + iter_d_x_data = tmp_d_x + cid; + for (int imid = 0; imid < imsize; + imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) { + T v_y = tmp_x[0]; + T dly = tmp_y[0]; + T dss = dp_scale; + T dbs = dp_bias; + T v_scale = 1.0, v_bias = 0.; + if (scale_data) v_scale = scale_data[gid * group_size + cid]; + if (bias_data) v_bias = bias_data[gid * group_size + cid]; + v_y -= v_bias; + if (v_scale != 0) v_y /= v_scale; + iter_d_x_data[0] = + (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) * + var_inv; + } + } + iter_x_data = iter_x_data_backup + group_size; + iter_y_data = iter_y_data_backup + group_size; + iter_d_x_data = iter_d_x_data_backup + group_size; + } + } + if (data_layout == DataLayout::kNHWC) { + iter_x_data = x_data + (bid + 1) * C * imsize; + iter_d_x_data = d_x_data + (bid + 1) * C * imsize; + iter_y_data = y_data + (bid + 1) * C * imsize; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + group_norm_grad, CPU, ALL_LAYOUT, phi::GroupNormGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/cpu/group_norm_kernel.cc b/paddle/phi/kernels/cpu/group_norm_kernel.cc new file mode 100644 index 00000000000..12aedf4cb44 --- /dev/null +++ b/paddle/phi/kernels/cpu/group_norm_kernel.cc @@ -0,0 +1,210 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
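+// GroupNormKernel (CPU forward): for each (batch, group) pair, accumulate the sum and
+// squared sum of x over the group's channels and spatial positions (8-way unrolled inner
+// loop), derive the mean and variance, then write y = (x - mean) / sqrt(var + epsilon),
+// optionally scaled and shifted per channel. Both NCHW and NHWC layouts are handled.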
+ +#include "paddle/phi/kernels/group_norm_kernel.h" + +#include +#include +#include +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/extensions.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void GroupNormKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_layout_str); + const auto scale_ptr = scale.get_ptr(); + const auto bias_ptr = bias.get_ptr(); + + const auto x_dims = x.dims(); + const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int group_size = C / groups; + + dev_ctx.template Alloc(y); + dev_ctx.template Alloc(mean); + dev_ctx.template Alloc(var); + + auto* x_data = x.data(); + auto* y_data = y->data(); + auto* mean_data = mean->data(); + auto* var_data = var->data(); + + const T* scale_data = nullptr; + if (scale_ptr) scale_data = scale_ptr->data(); + const T* bias_data = nullptr; + if (bias_ptr) bias_data = bias_ptr->data(); + + int imsize = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < x_dims.size(); ++i) { + imsize *= x_dims[i]; + } + } else { + for (int i = 1; i < x_dims.size() - 1; ++i) { + imsize *= x_dims[i]; + } + } + auto* iter_x_data = x_data; + auto* iter_y_data = y_data; + for (int bid = 0; bid < x_dims[0]; bid++) { + for (int gid = 0; gid < groups; gid++) { + const int64_t M = 8; + std::array x_mean_arr; + std::array x_var_arr; + std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0)); + std::fill(x_var_arr.begin(), x_var_arr.end(), T(0)); + T x_mean = 0, x_var = 0; + int number = std::min(group_size, static_cast(C - gid * group_size)); + auto* tmp_x = iter_x_data; + auto* x_src_data = iter_x_data; + auto* tmp_y = iter_y_data; + auto* y_src_data = iter_y_data; + + if (data_layout == DataLayout::kNCHW) { + for (int cid = 0; cid < number; cid++) { + int imid; + for (imid = 0; imid < imsize - (imsize % M); + imid += M, iter_x_data += M) { + // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used + // in template class/function, before we complete high + // performance cpu vector extension, temporarily unrolling + // loop to get high precision and performance + x_mean_arr[0] += iter_x_data[0]; + x_var_arr[0] += iter_x_data[0] * iter_x_data[0]; + x_mean_arr[1] += iter_x_data[1]; + x_var_arr[1] += iter_x_data[1] * iter_x_data[1]; + x_mean_arr[2] += iter_x_data[2]; + x_var_arr[2] += iter_x_data[2] * iter_x_data[2]; + x_mean_arr[3] += iter_x_data[3]; + x_var_arr[3] += iter_x_data[3] * iter_x_data[3]; + x_mean_arr[4] += iter_x_data[4]; + x_var_arr[4] += iter_x_data[4] * iter_x_data[4]; + x_mean_arr[5] += iter_x_data[5]; + x_var_arr[5] += iter_x_data[5] * iter_x_data[5]; + x_mean_arr[6] += iter_x_data[6]; + x_var_arr[6] += iter_x_data[6] * iter_x_data[6]; + x_mean_arr[7] += iter_x_data[7]; + x_var_arr[7] += iter_x_data[7] * iter_x_data[7]; + } + x_mean = + std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean); + x_var = std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var); + std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0)); + 
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0)); + for (; imid < imsize; imid++, iter_x_data++) { + x_mean += iter_x_data[0]; + x_var += iter_x_data[0] * iter_x_data[0]; + } + } + } else { + for (int cid = 0; cid < number; cid++) { + iter_x_data = tmp_x + cid; + int imid; + for (imid = 0; imid < imsize - (imsize % M); + imid += M, iter_x_data += M * C) { + // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used + // in template class/function, before we complete high + // performance cpu vector extension, temporarily unrolling + // loop to get high precision and performance + x_mean_arr[0] += iter_x_data[0 * C]; + x_var_arr[0] += iter_x_data[0 * C] * iter_x_data[0 * C]; + x_mean_arr[1] += iter_x_data[1 * C]; + x_var_arr[1] += iter_x_data[1 * C] * iter_x_data[1 * C]; + x_mean_arr[2] += iter_x_data[2 * C]; + x_var_arr[2] += iter_x_data[2 * C] * iter_x_data[2 * C]; + x_mean_arr[3] += iter_x_data[3 * C]; + x_var_arr[3] += iter_x_data[3 * C] * iter_x_data[3 * C]; + x_mean_arr[4] += iter_x_data[4 * C]; + x_var_arr[4] += iter_x_data[4 * C] * iter_x_data[4 * C]; + x_mean_arr[5] += iter_x_data[5 * C]; + x_var_arr[5] += iter_x_data[5 * C] * iter_x_data[5 * C]; + x_mean_arr[6] += iter_x_data[6 * C]; + x_var_arr[6] += iter_x_data[6 * C] * iter_x_data[6 * C]; + x_mean_arr[7] += iter_x_data[7 * C]; + x_var_arr[7] += iter_x_data[7 * C] * iter_x_data[7 * C]; + } + x_mean = + std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean); + x_var = std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var); + std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0)); + std::fill(x_var_arr.begin(), x_var_arr.end(), T(0)); + for (; imid < imsize; imid++, iter_x_data += C) { + x_mean += iter_x_data[0]; + x_var += iter_x_data[0] * iter_x_data[0]; + } + } + iter_x_data = tmp_x + group_size; + } + + x_mean /= number * imsize; + x_var /= number * imsize; + x_var = std::max(x_var - x_mean * x_mean, T(0)); + T var_inv = T(1) / std::sqrt(x_var + epsilon); + mean_data[bid * groups + gid] = x_mean; + var_data[bid * groups + gid] = x_var; + + if (data_layout == DataLayout::kNCHW) { + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; imid++, tmp_x++, iter_y_data++) { + T val = (tmp_x[0] - x_mean) * var_inv; + if (scale_data) val *= scale_data[gid * group_size + cid]; + if (bias_data) val += bias_data[gid * group_size + cid]; + iter_y_data[0] = val; + } + } + } else { + for (int cid = 0; cid < number; cid++) { + tmp_x = x_src_data + cid; + iter_y_data = y_src_data + cid; + for (int imid = 0; imid < imsize; + imid++, tmp_x += C, iter_y_data += C) { + T val = (tmp_x[0] - x_mean) * var_inv; + if (scale_data) val *= scale_data[gid * group_size + cid]; + if (bias_data) val += bias_data[gid * group_size + cid]; + iter_y_data[0] = val; + } + } + iter_y_data = tmp_y + group_size; + } + } + if (data_layout == DataLayout::kNHWC) { + iter_x_data = x_data + (bid + 1) * C * imsize; + iter_y_data = y_data + (bid + 1) * C * imsize; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + group_norm, CPU, ALL_LAYOUT, phi::GroupNormKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu new file mode 100644 index 00000000000..8af66fe0f29 --- /dev/null +++ b/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu @@ -0,0 +1,452 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/gpu/group_norm_utils.h" +#include "paddle/phi/kernels/group_norm_grad_kernel.h" + +namespace phi { + +template +__global__ void GroupNormBackwardGetMeanAndVar(const T* x, + const T* scale, + const T* bias, + const T* d_y, + int N, + int C, + int W, + int imsize, + int groups, + int group_size, + T epsilon, + T* d_mean, + T* d_var, + T* d_scale, + T* d_bias) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int H = imsize / W; + int number = min(group_size, static_cast(C - gid * group_size)); + int ccid = gid * group_size + cid; + if (ccid >= C) return; + T x_scale = (flags & kHasScale) ? scale[ccid] : 1; + T x_bias = (flags & kHasBias) ? bias[ccid] : 0; + T x_scale_inv = 0; + if (x_scale != 0) x_scale_inv = 1.0 / x_scale; + T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0; + + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + T val, dval; + + int hid = imid / W; + int wid = imid % W; + val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias; + dval = d_y[(bid * H + hid) * W * C + wid * C + ccid]; + + d_var_data += val * dval; + d_mean_data += dval * x_scale; + + val = val * x_scale_inv; + d_bias_data += dval; + d_scale_data += val * dval; + } + CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data); + CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data); + if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data); + if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data); +} + +template +__global__ void GroupNormBackward(const T* x, + const T* d_y, + const T* scale, + const T* bias, + const T* var, + const T* d_mean, + const T* d_var, + int N, + int C, + int W, + int imsize, + int groups, + int group_size, + T epsilon, + T* d_x) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int H = imsize / W; + int number = min(group_size, static_cast(C - gid * group_size)); + int ccid = gid * group_size + cid; + if (ccid >= C) return; + T x_var = var[bid * groups + gid]; + T d_x_mean = d_mean[bid * groups + gid]; + T d_x_var = d_var[bid * groups + gid]; + + T x_var_inv = 1.0 / sqrt(x_var + epsilon); + T number_inv = 1.0 / (number * imsize); + + T x_scale = (flags & kHasScale) ? scale[ccid] : 1; + T x_bias = (flags & kHasBias) ? 
bias[ccid] : 0; + T x_scale_inv = 0; + if (x_scale != 0) x_scale_inv = 1.0 / x_scale; + + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + int hid = imid / W; + int wid = imid % W; + T tmp = x[(bid * H + hid) * W * C + wid * C + ccid]; + T v_y = (tmp - x_bias) * x_scale_inv; + T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid]; + d_x[(bid * H + hid) * W * C + wid * C + ccid] = + x_var_inv * + (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean); + } +} + +template +__global__ void ScalarGetDsDbCUDAKernel( + int imsize, const T* x, const T* dy, T* ds, T* db) { + const int nc = blockIdx.x; + T ds_sum = 0; + T db_sum = 0; + for (int i = threadIdx.x; i < imsize; i += blockDim.x) { + const int index = nc * imsize + i; + ds_sum += dy[index] * x[index]; + db_sum += dy[index]; + } + ReduceMeanAndVar(db, ds, db_sum, ds_sum, 1); +} + +template +__global__ void GetScaleBiasGradientCUDAKernel(int N, + int C, + int group, + T epsilon, + const T* mean, + const T* var, + const T* ds, + const T* db, + T* d_scale, + T* d_bias) { + const int c = blockIdx.x * blockDim.x + threadIdx.x; + if (c < C) { + const int G = group; + const int D = C / G; + T sum1 = 0; + T sum2 = 0; + for (int n = 0; n < N; ++n) { + const int nc = n * C + c; + const int ng = n * G + c / D; + sum1 += (d_scale == nullptr) + ? T(0) + : ((ds[nc] - db[nc] * static_cast(mean[ng])) * + static_cast(rsqrt(var[ng] + epsilon))); + sum2 += (d_bias == nullptr) ? T(0) : db[nc]; + } + if (d_scale != nullptr) { + d_scale[c] = sum1; + } + if (d_bias != nullptr) { + d_bias[c] = sum2; + } + } +} + +template +__global__ void GetBackwardParamsCUDAKernel(int imsize, + int groups, + int group_size, + T epsilon, + const T* mean, + const T* var, + const T* scale, + const T* ds, + const T* db, + T* p1, + T* p2, + T* p3) { + const int n = blockIdx.x; + const int g = blockIdx.y; + const int ng = n * groups + g; + T sum1 = 0; + T sum2 = 0; + T var_inv = rsqrt(var[ng] + epsilon); + for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) { + const int64_t index = ng * group_size + i; + const int64_t c = g * group_size + i; + const T scale_v = scale == nullptr ? T(1) : static_cast(scale[c]); + sum1 += ds[index] * scale_v; + sum2 += db[index] * scale_v; + const T scale_c = scale == nullptr ? 
T(0) : static_cast(scale[c]); + p1[index] = scale_c * var_inv; + } + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage ds_storage; + __shared__ typename BlockReduce::TempStorage db_storage; + sum1 = BlockReduce(ds_storage).Reduce(sum1, cub::Sum()); + sum2 = BlockReduce(db_storage).Reduce(sum2, cub::Sum()); + + if (threadIdx.x == 0) { + const T s = T(1) / static_cast(group_size * imsize); + const T x = (sum2 * static_cast(mean[ng]) - sum1) * + static_cast(var_inv) * static_cast(var_inv) * + static_cast(var_inv) * s; + p2[ng] = x; + p3[ng] = -x * static_cast(mean[ng]) - sum2 * static_cast(var_inv) * s; + } +} + +template +__global__ void GetXGradientCUDAKernel(int imsize, + int C, + int group_size, + int groups, + T* p1, + T* p2, + T* p3, + const T* x, + const T* dy, + T* dx) { + int cid = blockIdx.x; + int gid = blockIdx.y; + int bid = blockIdx.z; + int ccid = bid * C + gid * group_size + cid; + int ng = bid * groups + gid; + int nc = gid * group_size + cid; + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + int index = (bid * C + nc) * imsize + imid; + dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng]; + } +} + +template +void GroupNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + const DenseTensor& y, + const DenseTensor& mean, + const DenseTensor& var, + const DenseTensor& d_y, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* d_x, + DenseTensor* d_scale, + DenseTensor* d_bias) { + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_layout_str); + const auto scale_ptr = scale.get_ptr(); + const auto bias_ptr = bias.get_ptr(); + + const auto& x_dims = x.dims(); + const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int group_size = C / groups; + const int W = (data_layout == DataLayout::kNCHW ? 
x_dims[x_dims.size() - 1] + : x_dims[x_dims.size() - 2]); + + dev_ctx.template Alloc(d_x); + phi::funcs::SetConstant set_zero; + + DenseTensor ds, db; + ds.Resize({x_dims[0], C}); + T* ds_data = dev_ctx.template Alloc(&ds); + db.Resize({x_dims[0], C}); + T* db_data = dev_ctx.template Alloc(&db); + + auto* y_data = y.data(); + auto* x_data = x.data(); + T* d_x_data = nullptr; + if (d_x) d_x_data = d_x->data(); + auto* dy_data = d_y.data(); + auto* var_data = var.data(); + auto* mean_data = mean.data(); + T* d_scale_data = nullptr; + if (d_scale) { + dev_ctx.template Alloc(d_scale); + d_scale_data = d_scale->data(); + } + T* d_bias_data = nullptr; + if (d_bias) { + dev_ctx.template Alloc(d_bias); + d_bias_data = d_bias->data(); + } + + const T* scale_data = nullptr; + if (scale_ptr) scale_data = scale_ptr->data(); + const T* bias_data = nullptr; + if (bias_ptr) bias_data = bias_ptr->data(); + + int imsize = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < x_dims.size(); ++i) { + imsize *= x_dims[i]; + } + } else { + for (int i = 1; i < x_dims.size() - 1; ++i) { + imsize *= x_dims[i]; + } + } + +#ifdef __HIPCC__ + int block_size = std::max(std::min(256, imsize), 64); + const int block_dims = 256; +#else + int block_size = std::min(1024, imsize); + const int block_dims = 1024; +#endif + dim3 grid(group_size, groups, x_dims[0]); + dim3 threads(block_size, 1, 1); + int flags = + (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; + if (data_layout == DataLayout::kNCHW) { + const int max_num_threads = 1024; + int max_block_size = std::min(imsize, max_num_threads); + int block_size_nchw = 1; + while (block_size_nchw < max_block_size) { + block_size_nchw *= 2; + } + block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize); + dim3 blocks(block_size_nchw); + ScalarGetDsDbCUDAKernel<<>>( + imsize, x_data, dy_data, ds_data, db_data); + + if (d_scale || d_bias) { + const int block = 256; + GetScaleBiasGradientCUDAKernel + <<<(C + block - 1) / block, block, 0, dev_ctx.stream()>>>( + x_dims[0], + C, + groups, + epsilon, + mean_data, + var_data, + ds_data, + db_data, + d_scale_data, + d_bias_data); + } + + if (d_x_data != nullptr) { + // p1 * dy + p2 * x + p3, + // p1, p2, p3 represent the reverse calculation of temporary variables + // p1 = scale * var_inv + // p2 = (db * scale * mean - ds * scale) * pow(var_inv, 3) * (1/n) + // p3 = -p2 * mean[ng] - db * scale * var_inv * (1/n); + DenseTensor p1, p2, p3; + p1.Resize({x_dims[0] * C}); + T* p1_data = dev_ctx.template Alloc(&p1); + p2.Resize({x_dims[0], groups}); + T* p2_data = dev_ctx.template Alloc(&p2); + p3.Resize({x_dims[0], groups}); + T* p3_data = dev_ctx.template Alloc(&p3); + + GetBackwardParamsCUDAKernel + <<>>( + imsize, + groups, + group_size, + epsilon, + mean_data, + var_data, + scale_data, + ds_data, + db_data, + p1_data, + p2_data, + p3_data); + GetXGradientCUDAKernel + <<>>(imsize, + C, + group_size, + groups, + p1_data, + p2_data, + p3_data, + x_data, + dy_data, + d_x_data); + } + } else { + if (d_scale) { + set_zero(dev_ctx, d_scale, static_cast(0)); + } + if (d_bias) { + set_zero(dev_ctx, d_bias, static_cast(0)); + } + + DenseTensor temp_var; + temp_var.Resize(var.dims()); + dev_ctx.template Alloc(&temp_var); + set_zero(dev_ctx, &temp_var, static_cast(0)); + T* temp_var_data = temp_var.data(); + + DenseTensor temp_mean; + temp_mean.Resize(var.dims()); + dev_ctx.template Alloc(&temp_mean); + set_zero(dev_ctx, &temp_mean, static_cast(0)); + T* temp_mean_data = temp_mean.data(); + + int 
flags = + (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; + UNROLL_ALL_CASES(flags, + GroupNormBackwardGetMeanAndVar, + y_data, + scale_data, + bias_data, + dy_data, + x_dims[0], + C, + W, + imsize, + groups, + group_size, + epsilon, + temp_mean_data, + temp_var_data, + d_scale_data, + d_bias_data); + if (d_x_data != nullptr) { + UNROLL_ALL_CASES(flags, + GroupNormBackward, + y_data, + dy_data, + scale_data, + bias_data, + var_data, + temp_mean_data, + temp_var_data, + x_dims[0], + C, + W, + imsize, + groups, + group_size, + epsilon, + d_x_data); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + group_norm_grad, GPU, ALL_LAYOUT, phi::GroupNormGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/gpu/group_norm_kernel.cu b/paddle/phi/kernels/gpu/group_norm_kernel.cu new file mode 100644 index 00000000000..127677233b8 --- /dev/null +++ b/paddle/phi/kernels/gpu/group_norm_kernel.cu @@ -0,0 +1,233 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/gpu/group_norm_utils.h" +#include "paddle/phi/kernels/group_norm_kernel.h" + +namespace phi { + +template +__global__ void GroupNormForwardGetMeanAndVar(const T* x, + int N, + int C, + int W, + int imsize, + int groups, + int group_size, + T* mean, + T* var) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int H = imsize / W; + int number = min(group_size, static_cast(C - gid * group_size)); + int ccid = gid * group_size + cid; + if (ccid >= C) return; + T x_mean = 0, x_var = 0; + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + T val; + int hid = imid / W; + int wid = imid % W; + val = x[(bid * H + hid) * W * C + wid * C + ccid]; + + x_mean += val; + x_var += val * val; + } + x_mean /= number * imsize; + x_var /= number * imsize; + CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean); + CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var); +} + +template +__global__ void GroupNormForward(const T* x, + const T* mean, + const T* var, + const T* scale, + const T* bias, + int N, + int C, + int W, + int imsize, + int groups, + int group_size, + T epsilon, + T* y, + T* real_var, + const DataLayout data_layout) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int H = imsize / W; + int ccid = gid * group_size + cid; + if (ccid >= C) return; + auto ng = bid * groups + gid; + T x_mean = mean[ng]; + T x_var = var[ng]; + x_var = x_var - x_mean * x_mean; + T var_inv = rsqrt(x_var + epsilon); + if (cid == 0 && threadIdx.x == 0) { + real_var[ng] = x_var; + } + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + T val; + int hid, wid; + int index = (bid * C + ccid) * imsize + imid; + if (data_layout == DataLayout::kNCHW) { + val = x[index]; 
+ } else { + hid = imid / W; + wid = imid % W; + val = x[(bid * H + hid) * W * C + wid * C + ccid]; + } + val = (val - x_mean) * var_inv; + if (flags & kHasScale) { + val *= scale[ccid]; + } + if (flags & kHasBias) { + val += bias[ccid]; + } + if (data_layout == DataLayout::kNCHW) { + y[index] = val; + } else { + y[(bid * H + hid) * W * C + wid * C + ccid] = val; + } + } +} + +template +void GroupNormKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_layout_str); + const auto scale_ptr = scale.get_ptr(); + const auto bias_ptr = bias.get_ptr(); + + const auto x_dims = x.dims(); + const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int group_size = C / groups; + + const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] + : x_dims[x_dims.size() - 2]); + + dev_ctx.template Alloc(y); + dev_ctx.template Alloc(mean); + dev_ctx.template Alloc(var); + phi::funcs::SetConstant set_zero; + DenseTensor temp_var; + temp_var.Resize(var->dims()); + dev_ctx.template Alloc(&temp_var); + auto* x_data = x.data(); + auto* y_data = y->data(); + auto* mean_data = mean->data(); + auto* var_data = var->data(); + auto* temp_var_data = temp_var.data(); + + const T* scale_data = nullptr; + if (scale_ptr) scale_data = scale_ptr->data(); + const T* bias_data = nullptr; + if (bias_ptr) bias_data = bias_ptr->data(); + + int imsize = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < x_dims.size(); ++i) { + imsize *= x_dims[i]; + } + } else { + for (int i = 1; i < x_dims.size() - 1; ++i) { + imsize *= x_dims[i]; + } + } + +#ifdef __HIPCC__ + int block_size = std::max(std::min(256, imsize), 64); +#else + int block_size = std::min(1024, imsize); +#endif + + dim3 grid(group_size, groups, x_dims[0]); + dim3 threads(block_size, 1, 1); + if (data_layout == DataLayout::kNCHW) { + using AccT = typename kps::details::MPTypeTrait::Type; + constexpr int vec_size = sizeof(float4) / sizeof(T); + int size = group_size * imsize; + const int max_num_threads = 1024; + int max_block_size = std::min(size / vec_size, max_num_threads); + int block_size_nchw = 1; + while (block_size_nchw < max_block_size) { + block_size_nchw *= 2; + } + block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize); + dim3 grids(x_dims[0] * groups); + dim3 blocks(block_size_nchw); + if (size < vec_size * block_size_nchw) { + ScalarGetMeanAndVarNCHW<<>>( + x_data, mean_data, temp_var_data, size); + } else { + VectorizedGetMeanAndVarNCHW + <<>>( + x_data, mean_data, temp_var_data, size); + } + } else { + set_zero(dev_ctx, mean, static_cast(0)); + set_zero(dev_ctx, &temp_var, static_cast(0)); + GroupNormForwardGetMeanAndVar + <<>>(x_data, + x_dims[0], + C, + W, + imsize, + groups, + group_size, + mean_data, + temp_var_data); + } + int flags = + (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; + UNROLL_ALL_CASES(flags, + GroupNormForward, + x_data, + mean_data, + temp_var_data, + scale_data, + bias_data, + x_dims[0], + C, + W, + imsize, + groups, + group_size, + epsilon, + y_data, + var_data, + data_layout); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + group_norm, GPU, ALL_LAYOUT, phi::GroupNormKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/group_norm_utils.h 
b/paddle/phi/kernels/gpu/group_norm_utils.h new file mode 100644 index 00000000000..6af7b96ca21 --- /dev/null +++ b/paddle/phi/kernels/gpu/group_norm_utils.h @@ -0,0 +1,174 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/kernels/primitive/kernel_primitives.h" + +namespace phi { + +enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 }; +#define ALIGN_BYTES 16 + +#define CHECK_CASE(i, flags, kernel_name, ...) \ + if (i == flags) { \ + kernel_name<<>>(__VA_ARGS__); \ + } + +// 0 for no scale, no bias +// 1 for has scale, no bias +// 2 for no scale, has bias +// 3 for has scale, has bias +#define UNROLL_ALL_CASES(flags, kernel_name, ...) \ + CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \ + CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \ + CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \ + CHECK_CASE(3, flags, kernel_name, __VA_ARGS__) + +template +__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { + typedef cub::WarpReduce WarpReduce; + typename WarpReduce::TempStorage temp_storage; + value = WarpReduce(temp_storage).Sum(value); + if (cub::LaneId() == 0) paddle::platform::CudaAtomicAdd(sum, value); +} + +template +__device__ __forceinline__ void ThreadReduce(phi::Array arrs, + int size, + const int offset, + AccT* out_mean, + AccT* out_var) { + const T* x = arrs[0]; + const T* y; + if (Num == 2) { + y = arrs[1]; + } + using VecT = kps::details::VectorType; + int tid = threadIdx.x; + if (offset > 0) { + x -= offset; + if (Num == 2) { + y -= offset; + } + size += offset; + if (tid >= offset) { + if (Num == 1) { + *out_mean += x[tid]; + *out_var += x[tid] * x[tid]; + } else if (Num == 2) { + *out_mean += y[tid]; + *out_var += y[tid] * x[tid]; + } + } + size -= blockDim.x; + x += blockDim.x; + if (Num == 2) { + y += blockDim.x; + } + } + int remain = size % (VecSize * blockDim.x); + + T ins_x[VecSize]; + T ins_y[VecSize]; + VecT* ins_vec_x = reinterpret_cast(&ins_x); + VecT* ins_vec_y = reinterpret_cast(&ins_y); + + // vector part + for (; VecSize * tid < (size - remain); tid += blockDim.x) { + *ins_vec_x = reinterpret_cast(x)[tid]; + if (Num == 2) { + *ins_vec_y = reinterpret_cast(y)[tid]; + } + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + if (Num == 1) { + *out_mean += ins_x[i]; + *out_var += ins_x[i] * ins_x[i]; + } else if (Num == 2) { + *out_mean += ins_y[i]; + *out_var += ins_y[i] * ins_x[i]; + } + } + } + + // scalar part + tid = size - remain + threadIdx.x; + for (; tid < size; tid += blockDim.x) { + if (Num == 1) { + *out_mean += x[tid]; + *out_var += x[tid] * x[tid]; + } else if (Num == 2) { + *out_mean += y[tid]; + *out_var += y[tid] * x[tid]; + } + } +} + +template +__device__ __forceinline__ void 
ReduceMeanAndVar( + T* mean, T* var, T x_mean, T x_var, int size) { + const int nc = blockIdx.x; + x_mean = kps::details::BlockXReduce>( + x_mean, kps::AddFunctor()); + x_var = kps::details::BlockXReduce>( + x_var, kps::AddFunctor()); + __syncthreads(); + if (threadIdx.x == 0) { + mean[nc] = static_cast(x_mean / size); + var[nc] = static_cast(x_var / size); + } +} + +template +__global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) { + int i = blockIdx.x; + T x_mean = 0, x_var = 0; + for (int j = threadIdx.x; j < size; j += blockDim.x) { + T val; + val = x[i * size + j]; + x_mean += val; + x_var += val * val; + } + ReduceMeanAndVar(mean, var, x_mean, x_var, size); +} + +template +__global__ void VectorizedGetMeanAndVarNCHW(const T* x, + T* mean, + T* var, + int size) { + int i = blockIdx.x; + AccT x_mean = static_cast(0); + AccT x_var = static_cast(0); + x += i * size; + const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T); + phi::Array ins; + ins[0] = x; + ThreadReduce(ins, size, input_offset, &x_mean, &x_var); + ReduceMeanAndVar(mean, var, x_mean, x_var, size); +} + +} // namespace phi diff --git a/paddle/phi/kernels/group_norm_grad_kernel.h b/paddle/phi/kernels/group_norm_grad_kernel.h new file mode 100644 index 00000000000..cc404f02132 --- /dev/null +++ b/paddle/phi/kernels/group_norm_grad_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GroupNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + const DenseTensor& y, + const DenseTensor& mean, + const DenseTensor& variance, + const DenseTensor& d_y, + float epsilon, + int groups, + const std::string& data_layout, + DenseTensor* d_x, + DenseTensor* d_scale, + DenseTensor* d_bias); + +} // namespace phi diff --git a/paddle/phi/kernels/group_norm_kernel.h b/paddle/phi/kernels/group_norm_kernel.h new file mode 100644 index 00000000000..36bf7125ec1 --- /dev/null +++ b/paddle/phi/kernels/group_norm_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
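+// Declares the phi forward kernel for group_norm: given x and optional per-channel
+// scale/bias, it produces the normalized output y together with the per-(batch, group)
+// mean and variance that the backward kernel consumes.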
+ +#pragma once + +#include + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GroupNormKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon, + int groups, + const std::string& data_layout, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* variance); + +} // namespace phi diff --git a/paddle/phi/ops/compat/group_norm_sig.cc b/paddle/phi/ops/compat/group_norm_sig.cc new file mode 100644 index 00000000000..d5a9cad97a2 --- /dev/null +++ b/paddle/phi/ops/compat/group_norm_sig.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature GroupNormOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("group_norm", + {"X", "Scale", "Bias"}, + {"epsilon", "groups", "data_layout"}, + {"Y", "Mean", "Variance"}); +} + +KernelSignature GroupNormGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "group_norm_grad", + {"X", "Scale", "Bias", "Y", "Mean", "Variance", "Y@GRAD"}, + {"epsilon", "groups", "data_layout"}, + {"X@GRAD", "Scale@GRAD", "Bias@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(group_norm, phi::GroupNormOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(group_norm_grad, + phi::GroupNormGradOpArgumentMapping); diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 26bda1a34ef..0f250fbd870 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1144,8 +1144,8 @@ class InstanceNorm(layers.Layer): def forward(self, input): if in_dygraph_mode(): - out, _, _, = _C_ops.final_state_instance_norm( - input, self.scale, self.bias, self._epsilon) + out = _C_ops.final_state_instance_norm(input, self.scale, self.bias, + self._epsilon) return out if _in_legacy_dygraph(): out, _, _ = _C_ops.instance_norm(input, self.scale, self.bias, @@ -3031,8 +3031,14 @@ class GroupNorm(layers.Layer): dtype=self._dtype, stop_gradient=True) variance_out = self._helper.create_variable_for_type_inference( dtype=self._dtype, stop_gradient=True) + if in_dygraph_mode(): + out = _C_ops.final_state_group_norm(input, self.weight, self.bias, + self._epsilon, self._groups, + "NCHW") - if _non_static_mode(): + return dygraph_utils._append_activation_in_dygraph(out, self._act) + + elif _in_legacy_dygraph(): attrs = ('epsilon', self._epsilon, 'groups', self._groups) out, _, _ = _C_ops.group_norm(input, self.weight, self.bias, mean_out, variance_out, *attrs) diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op.py b/python/paddle/fluid/tests/unittests/test_group_norm_op.py index 94793ad85cf..179b197cf62 100644 --- a/python/paddle/fluid/tests/unittests/test_group_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op.py @@ -20,7 +20,7 @@ from operator import mul import 
paddle.fluid.core as core import paddle.fluid as fluid from op_test import OpTest, skip_check_grad_ci - +from paddle.fluid.framework import _test_eager_guard from testsuite import create_op @@ -301,5 +301,30 @@ class TestGroupNormException(unittest.TestCase): self.assertRaises(ValueError, attr_data_format) +class TestGroupNormEager(unittest.TestCase): + + def test_dygraph_final_state_api(self): + self.dtype = np.float64 + self.shape = (8, 32, 32) + input = np.random.random(self.shape).astype(self.dtype) + + with fluid.dygraph.guard(): + tensor_1 = fluid.dygraph.to_variable(input) + tensor_1.stop_gradient = False + groupNorm = fluid.dygraph.nn.GroupNorm(channels=32, groups=4) + ret1 = groupNorm(tensor_1) + ret1.backward() + with _test_eager_guard(): + tensor_eager_1 = fluid.dygraph.to_variable(input) + tensor_eager_1.stop_gradient = False + groupNorm_eager = fluid.dygraph.nn.GroupNorm(channels=32, + groups=4) + ret2 = groupNorm_eager(tensor_eager_1) + ret2.backward() + self.assertEqual(( + tensor_1.grad.numpy() == tensor_eager_1.grad.numpy()).all(), + True) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py index c6bc44ebd2f..42f97585172 100644 --- a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py @@ -22,6 +22,7 @@ from op_test import OpTest, _set_use_system_allocator from paddle.fluid.framework import grad_var_name import paddle.fluid as fluid from paddle.fluid import Program, program_guard +from paddle.fluid.framework import _test_eager_guard import paddle @@ -124,6 +125,10 @@ class TestDygraphGroupNormv2(unittest.TestCase): y2 = compute_v2(x) self.assertTrue(np.allclose(y1, y2, atol=1e-5)) + def test_eager_api(self): + with _test_eager_guard(): + self.test_dygraph() + class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase): @@ -154,6 +159,10 @@ class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase): self.assertTrue(np.allclose(result1, expect_res1, atol=1e-5)) self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5)) + def test_eager_api(self): + with _test_eager_guard(): + self.test_numerical_accuracy() + class TestGroupNormDimException(unittest.TestCase): diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 7bc9f105cac..e40731b828d 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -413,7 +413,7 @@ def instance_norm(x, """ if in_dygraph_mode(): - out, _, _, = _C_ops.final_state_instance_norm(x, weight, bias, eps) + out = _C_ops.final_state_instance_norm(x, weight, bias, eps) return out if _in_legacy_dygraph(): out, _, _ = _C_ops.instance_norm(x, weight, bias, "epsilon", eps, diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 8ed4832a8f7..fd000567c50 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -933,6 +933,17 @@ kernel : func : greater_than +- api : group_norm + args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int groups, str data_layout) + output : Tensor(y), Tensor(mean), Tensor(variance) + infer_meta : + func : GroupNormInferMeta + kernel : + func : group_norm + optional : scale, bias + intermediate : mean, variance + backward : group_norm_grad + - api : gumbel_softmax args : (Tensor x, float temperature, bool hard, int axis) output : Tensor @@ -1039,6 
+1050,7 @@ func : instance_norm data_type : x optional : scale, bias + intermediate : saved_mean, saved_variance backward : instance_norm_grad # is_empty diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 6a555fd24a0..81641ac19f7 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -844,6 +844,19 @@ data_type : out_grad optional: out, dst_count +- backward_api : group_norm_grad + forward : group_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int groups, str data_layout) -> Tensor(y), Tensor(mean), Tensor(variance) + args : (Tensor x, Tensor scale, Tensor bias, Tensor y, Tensor mean, Tensor variance, Tensor y_grad, float epsilon, int groups, str data_layout) + output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) + infer_meta : + func : GeneralTernaryGradInferMeta + param : [y, scale, bias] + kernel : + func : group_norm_grad + data_type : y_grad + optional: scale, bias + inplace : (y_grad -> x_grad) + - backward_api : gumbel_softmax_grad forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out) args : (Tensor out, Tensor out_grad, int axis) diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json index 75533311513..c788128c63c 100644 --- a/tools/infrt/skipped_phi_api.json +++ b/tools/infrt/skipped_phi_api.json @@ -1,4 +1,4 @@ { -"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm", "instance_norm"], +"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm", "instance_norm", "group_norm"], "phi_kernels":["equal_all"] } -- GitLab
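For reference, a minimal NumPy sketch of the forward computation implemented by the new phi CPU/GPU kernels (assuming, as GroupNormInferMeta enforces, that the channel count is divisible by groups; the name group_norm_ref and its default arguments are illustrative and not part of the patch):

import numpy as np

def group_norm_ref(x, scale=None, bias=None, epsilon=1e-5, groups=4, data_layout="NCHW"):
    # Statistics are taken per (batch, group) over the group's channels and all
    # spatial positions, matching the op's Mean/Variance outputs of shape [N, groups].
    if data_layout == "NHWC":
        x = np.moveaxis(x, -1, 1)                  # work in NCHW internally
    N, C = x.shape[0], x.shape[1]
    xg = x.reshape(N, groups, -1)                  # [N, G, (C/G) * spatial]
    mean = xg.mean(axis=-1)                        # [N, G]
    var = xg.var(axis=-1)                          # biased variance, [N, G]
    y = (xg - mean[..., None]) / np.sqrt(var[..., None] + epsilon)
    y = y.reshape(x.shape)
    cshape = (1, C) + (1,) * (x.ndim - 2)          # broadcast per channel
    if scale is not None:
        y = y * scale.reshape(cshape)
    if bias is not None:
        y = y + bias.reshape(cshape)
    if data_layout == "NHWC":
        y = np.moveaxis(y, 1, -1)
    return y, mean, var

An output computed this way is essentially what the numerical-accuracy tests in test_group_norm_op_v2.py check the paddle implementation against, up to floating-point tolerance.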