Unverified commit 99c6497b authored by YuanRisheng, committed by GitHub

[Phi] Move group_norm op kernel into PHI and add yaml / unittest (#43104)

* move_group_norm

* move group norm backward

* fix code format

* modify code according to review comments
Parent 8bd3514c
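For context, a minimal sketch of how the migrated kernels are exercised after this change, modeled on the unit test added in this diff (the layer, shapes, and eager guard are taken from that test; everything else is illustrative):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard

x = np.random.random((8, 32, 32)).astype('float64')
with fluid.dygraph.guard():
    with _test_eager_guard():  # eager mode dispatches to the new PHI kernels
        t = fluid.dygraph.to_variable(x)
        t.stop_gradient = False
        gn = fluid.dygraph.nn.GroupNorm(channels=32, groups=4)
        y = gn(t)      # forward: phi::GroupNormKernel
        y.backward()   # backward: phi::GroupNormGradKernel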
......@@ -12,13 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/group_norm_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
......@@ -29,91 +33,6 @@ using DataLayout = framework::DataLayout;
class GroupNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GroupNorm");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "GroupNorm");
OP_INOUT_CHECK(ctx->HasOutput("Mean"), "Output", "Mean", "GroupNorm");
OP_INOUT_CHECK(ctx->HasOutput("Variance"), "Output", "Variance",
"GroupNorm");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(
x_dim.size(), 2,
platform::errors::InvalidArgument(
"The Input(X)'s dimension of Op(group_norm) must be "
"greater than 1. But received: %u-D Tensor, which shape is [%s].",
x_dim.size(), x_dim));
const std::string data_layout_str =
ctx->Attrs().Get<std::string>("data_layout");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const int64_t channel_num =
(data_layout == DataLayout::kNCHW ? x_dim[1] : x_dim[x_dim.size() - 1]);
auto batch_size = x_dim[0];
auto groups = ctx->Attrs().Get<int>("groups");
PADDLE_ENFORCE_LE(
groups, channel_num,
platform::errors::InvalidArgument(
"The Attr(groups) of Op(group_norm) must be less than or "
"equal to the number of channels. But received: groups "
"is [%s], channels is [%s], the Attr(data_layout) "
"is [%s]. The error may come from wrong data_layout setting.",
groups, channel_num, data_layout_str));
PADDLE_ENFORCE_GE(
groups, 1,
platform::errors::InvalidArgument(
"The Attr(groups) of Op(group_norm) must be "
"greater than or equal to 1. But received: groups is [%s].",
groups));
PADDLE_ENFORCE_EQ(
channel_num % groups, 0,
platform::errors::InvalidArgument(
"Expected number of channels in input to be divisible by "
"num_groups, but got input channel is %d and num_groups is %d",
channel_num, groups));
if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Scale").size(), 1UL,
platform::errors::InvalidArgument(
"The Input(Scale) of Op(group_norm) should be 1-D Tensor. "
"But received: %u-D Tensor, the shape of Input(Scale) is [%s].",
ctx->GetInputDim("Scale").size(), ctx->GetInputDim("Scale")));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Scale")[0], channel_num,
platform::errors::InvalidArgument(
"The Input(Scale)'s first dimension size of Op(group_norm) must "
"be equal to the number of channels. But received: the "
"Input(Scale)'s first dimension size is [%s], the channels is "
"[%s], the Attr(data_layout) is [%s]. The error may come "
"from wrong data_layout setting.",
ctx->GetInputDim("Scale")[0], channel_num, data_layout_str));
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Bias").size(), 1UL,
platform::errors::InvalidArgument(
"The Input(Bias) of Op(group_norm) should be 1-D Tensor. "
"But received: %u-D Tensor, the shape of Input(Bias) is [%s].",
ctx->GetInputDim("Bias").size(), ctx->GetInputDim("Bias")));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Bias")[0], channel_num,
platform::errors::InvalidArgument(
"The Input(Bias)'s first dimension size of "
"Op(group_norm) must be equal to the number of channels. "
"But received: the Input(Bias)'s first dimension size is [%s], "
"the channels is [%s], the Attr(data_layout) is [%s]. The "
"error may come from wrong data_layout setting.",
ctx->GetInputDim("Bias")[0], channel_num, data_layout_str));
}
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->SetOutputDim("Mean", {batch_size, groups});
ctx->SetOutputDim("Variance", {batch_size, groups});
ctx->ShareLoD("X", "Y");
}
};
class GroupNormOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -252,17 +171,14 @@ class GroupNormOpInferVarType
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(group_norm, GroupNormInferShapeFunctor,
PD_INFER_META(phi::GroupNormInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker,
ops::GroupNormOpInferVarType,
ops::GroupNormGradMaker<paddle::framework::OpDesc>,
ops::GroupNormGradMaker<paddle::imperative::OpBase>);
ops::GroupNormGradMaker<paddle::imperative::OpBase>,
GroupNormInferShapeFunctor);
REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp,
ops::GroupNormGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::GroupNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
group_norm_grad,
ops::GroupNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GroupNormGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -14,7 +14,8 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/group_norm_op.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
......@@ -363,6 +364,122 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
}
}
void GroupNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
float epsilon,
int groups,
const std::string& data_layout_str,
MetaTensor* y,
MetaTensor* mean,
MetaTensor* variance) {
PADDLE_ENFORCE_NE(y,
nullptr,
phi::errors::InvalidArgument(
"The y in GroupNormInferMeta can't be nullptr."));
PADDLE_ENFORCE_NE(mean,
nullptr,
phi::errors::InvalidArgument(
"The mean in GroupNormInferMeta can't be nullptr."));
PADDLE_ENFORCE_NE(
variance,
nullptr,
phi::errors::InvalidArgument(
"The variance in GroupNormInferMeta can't be nullptr."));
auto x_dim = x.dims();
PADDLE_ENFORCE_GE(
x_dim.size(),
2,
phi::errors::InvalidArgument(
"The Input(X)'s dimension of Op(group_norm) must be "
"greater than 1. But received: %u-D Tensor, which shape is [%s].",
x_dim.size(),
x_dim));
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
const int64_t channel_num =
(data_layout == DataLayout::kNCHW ? x_dim[1] : x_dim[x_dim.size() - 1]);
auto batch_size = x_dim[0];
PADDLE_ENFORCE_LE(
groups,
channel_num,
phi::errors::InvalidArgument(
"The Attr(groups) of Op(group_norm) must be less than or "
"equal to the number of channels. But received: groups "
"is [%s], channels is [%s], the Attr(data_layout) "
"is [%s]. The error may come from wrong data_layout setting.",
groups,
channel_num,
data_layout_str));
PADDLE_ENFORCE_GE(
groups,
1,
phi::errors::InvalidArgument(
"The Attr(groups) of Op(group_norm) must be "
"greater than or equal to 1. But received: groups is [%s].",
groups));
PADDLE_ENFORCE_EQ(
channel_num % groups,
0,
phi::errors::InvalidArgument(
"Expected number of channels in input to be divisible by "
"num_groups, but got input channel is %d and num_groups is %d",
channel_num,
groups));
if (scale) {
PADDLE_ENFORCE_EQ(
scale.dims().size(),
1UL,
phi::errors::InvalidArgument(
"The Input(Scale) of Op(group_norm) should be 1-D Tensor. "
"But received: %u-D Tensor, the shape of Input(Scale) is [%s].",
scale.dims().size(),
scale.dims()));
PADDLE_ENFORCE_EQ(
scale.dims()[0],
channel_num,
phi::errors::InvalidArgument(
"The Input(Scale)'s first dimension size of Op(group_norm) must "
"be equal to the number of channels. But received: the "
"Input(Scale)'s first dimension size is [%s], the channels is "
"[%s], the Attr(data_layout) is [%s]. The error may come "
"from wrong data_layout setting.",
scale.dims()[0],
channel_num,
data_layout_str));
}
if (bias) {
PADDLE_ENFORCE_EQ(
bias.dims().size(),
1UL,
phi::errors::InvalidArgument(
"The Input(Bias) of Op(group_norm) should be 1-D Tensor. "
"But received: %u-D Tensor, the shape of Input(Bias) is [%s].",
bias.dims().size(),
bias.dims()));
PADDLE_ENFORCE_EQ(
bias.dims()[0],
channel_num,
phi::errors::InvalidArgument(
"The Input(Bias)'s first dimension size of "
"Op(group_norm) must be equal to the number of channels. "
"But received: the Input(Bias)'s first dimension size is [%s], "
"the channels is [%s], the Attr(data_layout) is [%s]. The "
"error may come from wrong data_layout setting.",
bias.dims()[0],
channel_num,
data_layout_str));
}
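// Y keeps X's shape, dtype and LoD; Mean and Variance hold per-(batch, group) statistics.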
y->set_dims(x_dim);
y->set_dtype(x.dtype());
y->share_lod(x);
mean->set_dims({batch_size, groups});
variance->set_dims({batch_size, groups});
}
void LayerNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
......
......@@ -69,6 +69,16 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* dst_count);
void GroupNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
float epsilon,
int groups,
const std::string& data_layout,
MetaTensor* y,
MetaTensor* mean,
MetaTensor* variance);
void LayerNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/group_norm_grad_kernel.h"
#include <algorithm>
#include <array>
#include <numeric>
#include <string>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void GroupNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& y,
const DenseTensor& mean,
const DenseTensor& var,
const DenseTensor& d_y,
float epsilon,
int groups,
const std::string& data_layout_str,
DenseTensor* d_x,
DenseTensor* d_scale,
DenseTensor* d_bias) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
const auto& x_dims = y.dims();
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = C / groups;
dev_ctx.template Alloc<T>(d_x);
phi::funcs::SetConstant<CPUContext, T> set_zero;
auto* x_data = y.data<T>();
auto* d_x_data = d_x->data<T>();
auto* y_data = d_y.data<T>();
auto* var_data = var.data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
dev_ctx.template Alloc<T>(d_scale);
set_zero(dev_ctx, d_scale, static_cast<T>(0));
d_scale_data = d_scale->data<T>();
}
T* d_bias_data = nullptr;
if (d_bias) {
dev_ctx.template Alloc<T>(d_bias);
set_zero(dev_ctx, d_bias, static_cast<T>(0));
d_bias_data = d_bias->data<T>();
}
const T* scale_data = nullptr;
if (scale_ptr) scale_data = scale_ptr->data<T>();
const T* bias_data = nullptr;
if (bias_ptr) bias_data = bias_ptr->data<T>();
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
auto* iter_x_data = x_data;
auto* iter_d_x_data = d_x_data;
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
for (int gid = 0; gid < groups; gid++) {
T x_var = var_data[bid * groups + gid];
T var_inv = 1.0 / sqrt(x_var + epsilon);
int number = std::min(group_size, static_cast<int>(C - gid * group_size));
T number_inv = 1.0 / (number * imsize);
auto* tmp_x = iter_x_data;
auto* tmp_y = iter_y_data;
auto* tmp_d_x = iter_d_x_data;
auto* x_src_data = iter_x_data;
auto* y_src_data = iter_y_data;
auto* iter_x_data_backup = iter_x_data;
auto* iter_y_data_backup = iter_y_data;
auto* iter_d_x_data_backup = iter_d_x_data;
T dp_scale = 0, dp_bias = 0;
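// Note: iter_x_data walks the forward output y. The first pass accumulates
// dp_scale = sum((y - bias) * dy) and, when scale is given,
// dp_bias = sum(scale * dy) over the group, plus per-channel d_scale/d_bias.
// The second pass applies
// d_x = (dy * scale - dp_scale * x_hat / n - dp_bias / n) * var_inv,
// with x_hat = (y - bias) / scale and n = number * imsize.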
if (data_layout == DataLayout::kNCHW) {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize;
imid++, iter_x_data++, iter_y_data++) {
T val = iter_x_data[0];
if (bias_data) val -= bias_data[gid * group_size + cid];
T dval = iter_y_data[0];
dp_scale += val * dval;
if (scale_data)
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data && scale_data[gid * group_size + cid] != 0)
val /= scale_data[gid * group_size + cid];
if (d_bias_data) d_bias_data[gid * group_size + cid] += dval;
if (d_scale_data)
d_scale_data[gid * group_size + cid] += val * dval;
}
}
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize;
imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
T v_y = tmp_x[0];
T dly = tmp_y[0];
T dss = dp_scale;
T dbs = dp_bias;
T v_scale = 1., v_bias = 0.;
if (scale_data) v_scale = scale_data[gid * group_size + cid];
if (bias_data) v_bias = bias_data[gid * group_size + cid];
v_y -= v_bias;
if (v_scale != 0) v_y /= v_scale;
iter_d_x_data[0] =
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
var_inv;
}
}
} else {
for (int cid = 0; cid < number; cid++) {
iter_x_data = x_src_data + cid;
iter_y_data = y_src_data + cid;
for (int imid = 0; imid < imsize;
imid++, iter_x_data += C, iter_y_data += C) {
T val = iter_x_data[0];
if (bias_data) val -= bias_data[gid * group_size + cid];
T dval = iter_y_data[0];
dp_scale += val * dval;
if (scale_data)
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data && scale_data[gid * group_size + cid] != 0)
val /= scale_data[gid * group_size + cid];
if (d_bias_data) d_bias_data[gid * group_size + cid] += dval;
if (d_scale_data)
d_scale_data[gid * group_size + cid] += val * dval;
}
}
for (int cid = 0; cid < number; cid++) {
tmp_x = x_src_data + cid;
tmp_y = y_src_data + cid;
iter_d_x_data = tmp_d_x + cid;
for (int imid = 0; imid < imsize;
imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
T v_y = tmp_x[0];
T dly = tmp_y[0];
T dss = dp_scale;
T dbs = dp_bias;
T v_scale = 1.0, v_bias = 0.;
if (scale_data) v_scale = scale_data[gid * group_size + cid];
if (bias_data) v_bias = bias_data[gid * group_size + cid];
v_y -= v_bias;
if (v_scale != 0) v_y /= v_scale;
iter_d_x_data[0] =
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
var_inv;
}
}
iter_x_data = iter_x_data_backup + group_size;
iter_y_data = iter_y_data_backup + group_size;
iter_d_x_data = iter_d_x_data_backup + group_size;
}
}
if (data_layout == DataLayout::kNHWC) {
iter_x_data = x_data + (bid + 1) * C * imsize;
iter_d_x_data = d_x_data + (bid + 1) * C * imsize;
iter_y_data = y_data + (bid + 1) * C * imsize;
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
group_norm_grad, CPU, ALL_LAYOUT, phi::GroupNormGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/group_norm_kernel.h"
#include <algorithm>
#include <array>
#include <numeric>
#include <string>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void GroupNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int groups,
const std::string& data_layout_str,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
const auto x_dims = x.dims();
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = C / groups;
dev_ctx.template Alloc<T>(y);
dev_ctx.template Alloc<T>(mean);
dev_ctx.template Alloc<T>(var);
auto* x_data = x.data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
const T* scale_data = nullptr;
if (scale_ptr) scale_data = scale_ptr->data<T>();
const T* bias_data = nullptr;
if (bias_ptr) bias_data = bias_ptr->data<T>();
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
auto* iter_x_data = x_data;
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
for (int gid = 0; gid < groups; gid++) {
const int64_t M = 8;
std::array<T, M> x_mean_arr;
std::array<T, M> x_var_arr;
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
T x_mean = 0, x_var = 0;
int number = std::min(group_size, static_cast<int>(C - gid * group_size));
auto* tmp_x = iter_x_data;
auto* x_src_data = iter_x_data;
auto* tmp_y = iter_y_data;
auto* y_src_data = iter_y_data;
if (data_layout == DataLayout::kNCHW) {
for (int cid = 0; cid < number; cid++) {
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M) {
// TODO(gaoxiang): AVX/AVX2/AVX512 intrinsics cannot be used directly in a
// template class/function. Until a high-performance CPU vector extension is
// in place, unroll the loop manually to keep both precision and performance.
x_mean_arr[0] += iter_x_data[0];
x_var_arr[0] += iter_x_data[0] * iter_x_data[0];
x_mean_arr[1] += iter_x_data[1];
x_var_arr[1] += iter_x_data[1] * iter_x_data[1];
x_mean_arr[2] += iter_x_data[2];
x_var_arr[2] += iter_x_data[2] * iter_x_data[2];
x_mean_arr[3] += iter_x_data[3];
x_var_arr[3] += iter_x_data[3] * iter_x_data[3];
x_mean_arr[4] += iter_x_data[4];
x_var_arr[4] += iter_x_data[4] * iter_x_data[4];
x_mean_arr[5] += iter_x_data[5];
x_var_arr[5] += iter_x_data[5] * iter_x_data[5];
x_mean_arr[6] += iter_x_data[6];
x_var_arr[6] += iter_x_data[6] * iter_x_data[6];
x_mean_arr[7] += iter_x_data[7];
x_var_arr[7] += iter_x_data[7] * iter_x_data[7];
}
x_mean =
std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
x_var = std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
for (; imid < imsize; imid++, iter_x_data++) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
} else {
for (int cid = 0; cid < number; cid++) {
iter_x_data = tmp_x + cid;
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M * C) {
// TODO(gaoxiang): AVX/AVX2/AVX512 intrinsics cannot be used directly in a
// template class/function. Until a high-performance CPU vector extension is
// in place, unroll the loop manually to keep both precision and performance.
x_mean_arr[0] += iter_x_data[0 * C];
x_var_arr[0] += iter_x_data[0 * C] * iter_x_data[0 * C];
x_mean_arr[1] += iter_x_data[1 * C];
x_var_arr[1] += iter_x_data[1 * C] * iter_x_data[1 * C];
x_mean_arr[2] += iter_x_data[2 * C];
x_var_arr[2] += iter_x_data[2 * C] * iter_x_data[2 * C];
x_mean_arr[3] += iter_x_data[3 * C];
x_var_arr[3] += iter_x_data[3 * C] * iter_x_data[3 * C];
x_mean_arr[4] += iter_x_data[4 * C];
x_var_arr[4] += iter_x_data[4 * C] * iter_x_data[4 * C];
x_mean_arr[5] += iter_x_data[5 * C];
x_var_arr[5] += iter_x_data[5 * C] * iter_x_data[5 * C];
x_mean_arr[6] += iter_x_data[6 * C];
x_var_arr[6] += iter_x_data[6 * C] * iter_x_data[6 * C];
x_mean_arr[7] += iter_x_data[7 * C];
x_var_arr[7] += iter_x_data[7 * C] * iter_x_data[7 * C];
}
x_mean =
std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
x_var = std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
for (; imid < imsize; imid++, iter_x_data += C) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
iter_x_data = tmp_x + group_size;
}
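// Group statistics via running sums: mean = E[x], var = max(E[x^2] - E[x]^2, 0);
// the clamp guards against small negative values from floating-point cancellation.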
x_mean /= number * imsize;
x_var /= number * imsize;
x_var = std::max(x_var - x_mean * x_mean, T(0));
T var_inv = T(1) / std::sqrt(x_var + epsilon);
mean_data[bid * groups + gid] = x_mean;
var_data[bid * groups + gid] = x_var;
if (data_layout == DataLayout::kNCHW) {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize; imid++, tmp_x++, iter_y_data++) {
T val = (tmp_x[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
} else {
for (int cid = 0; cid < number; cid++) {
tmp_x = x_src_data + cid;
iter_y_data = y_src_data + cid;
for (int imid = 0; imid < imsize;
imid++, tmp_x += C, iter_y_data += C) {
T val = (tmp_x[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
iter_y_data = tmp_y + group_size;
}
}
if (data_layout == DataLayout::kNHWC) {
iter_x_data = x_data + (bid + 1) * C * imsize;
iter_y_data = y_data + (bid + 1) * C * imsize;
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
group_norm, CPU, ALL_LAYOUT, phi::GroupNormKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/group_norm_utils.h"
#include "paddle/phi/kernels/group_norm_grad_kernel.h"
namespace phi {
template <typename T, int flags>
__global__ void GroupNormBackwardGetMeanAndVar(const T* x,
const T* scale,
const T* bias,
const T* d_y,
int N,
int C,
int W,
int imsize,
int groups,
int group_size,
T epsilon,
T* d_mean,
T* d_var,
T* d_scale,
T* d_bias) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int H = imsize / W;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
T x_scale_inv = 0;
if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val, dval;
int hid = imid / W;
int wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias;
dval = d_y[(bid * H + hid) * W * C + wid * C + ccid];
d_var_data += val * dval;
d_mean_data += dval * x_scale;
val = val * x_scale_inv;
d_bias_data += dval;
d_scale_data += val * dval;
}
CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data);
CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data);
if (flags & kHasScale) CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
if (flags & kHasBias) CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
}
template <typename T, int flags>
__global__ void GroupNormBackward(const T* x,
const T* d_y,
const T* scale,
const T* bias,
const T* var,
const T* d_mean,
const T* d_var,
int N,
int C,
int W,
int imsize,
int groups,
int group_size,
T epsilon,
T* d_x) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int H = imsize / W;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_var = var[bid * groups + gid];
T d_x_mean = d_mean[bid * groups + gid];
T d_x_var = d_var[bid * groups + gid];
T x_var_inv = 1.0 / sqrt(x_var + epsilon);
T number_inv = 1.0 / (number * imsize);
T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
T x_scale_inv = 0;
if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
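// In the NHWC path this kernel receives the forward output y as `x`. Recover
// x_hat = (x - bias) / scale and apply
// d_x = var_inv * (dy * scale - d_var * x_hat / n - d_mean / n),
// with n = number * imsize and d_mean/d_var accumulated by
// GroupNormBackwardGetMeanAndVar.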
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
int hid = imid / W;
int wid = imid % W;
T tmp = x[(bid * H + hid) * W * C + wid * C + ccid];
T v_y = (tmp - x_bias) * x_scale_inv;
T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid];
d_x[(bid * H + hid) * W * C + wid * C + ccid] =
x_var_inv *
(dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
}
}
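// Per (n, c) pair: ds = sum(dy * x) and db = sum(dy) over the spatial extent.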
template <typename T>
__global__ void ScalarGetDsDbCUDAKernel(
int imsize, const T* x, const T* dy, T* ds, T* db) {
const int nc = blockIdx.x;
T ds_sum = 0;
T db_sum = 0;
for (int i = threadIdx.x; i < imsize; i += blockDim.x) {
const int index = nc * imsize + i;
ds_sum += dy[index] * x[index];
db_sum += dy[index];
}
ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1);
}
template <typename T>
__global__ void GetScaleBiasGradientCUDAKernel(int N,
int C,
int group,
T epsilon,
const T* mean,
const T* var,
const T* ds,
const T* db,
T* d_scale,
T* d_bias) {
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < C) {
const int G = group;
const int D = C / G;
T sum1 = 0;
T sum2 = 0;
for (int n = 0; n < N; ++n) {
const int nc = n * C + c;
const int ng = n * G + c / D;
sum1 += (d_scale == nullptr)
? T(0)
: ((ds[nc] - db[nc] * static_cast<T>(mean[ng])) *
static_cast<T>(rsqrt(var[ng] + epsilon)));
sum2 += (d_bias == nullptr) ? T(0) : db[nc];
}
if (d_scale != nullptr) {
d_scale[c] = sum1;
}
if (d_bias != nullptr) {
d_bias[c] = sum2;
}
}
}
template <typename T, int BlockDim>
__global__ void GetBackwardParamsCUDAKernel(int imsize,
int groups,
int group_size,
T epsilon,
const T* mean,
const T* var,
const T* scale,
const T* ds,
const T* db,
T* p1,
T* p2,
T* p3) {
const int n = blockIdx.x;
const int g = blockIdx.y;
const int ng = n * groups + g;
T sum1 = 0;
T sum2 = 0;
T var_inv = rsqrt(var[ng] + epsilon);
for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
const int64_t index = ng * group_size + i;
const int64_t c = g * group_size + i;
const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]);
sum1 += ds[index] * scale_v;
sum2 += db[index] * scale_v;
const T scale_c = scale == nullptr ? T(0) : static_cast<T>(scale[c]);
p1[index] = scale_c * var_inv;
}
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
sum1 = BlockReduce(ds_storage).Reduce(sum1, cub::Sum());
sum2 = BlockReduce(db_storage).Reduce(sum2, cub::Sum());
if (threadIdx.x == 0) {
const T s = T(1) / static_cast<T>(group_size * imsize);
const T x = (sum2 * static_cast<T>(mean[ng]) - sum1) *
static_cast<T>(var_inv) * static_cast<T>(var_inv) *
static_cast<T>(var_inv) * s;
p2[ng] = x;
p3[ng] = -x * static_cast<T>(mean[ng]) - sum2 * static_cast<T>(var_inv) * s;
}
}
template <typename T>
__global__ void GetXGradientCUDAKernel(int imsize,
int C,
int group_size,
int groups,
T* p1,
T* p2,
T* p3,
const T* x,
const T* dy,
T* dx) {
int cid = blockIdx.x;
int gid = blockIdx.y;
int bid = blockIdx.z;
int ccid = bid * C + gid * group_size + cid;
int ng = bid * groups + gid;
int nc = gid * group_size + cid;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
int index = (bid * C + nc) * imsize + imid;
dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng];
}
}
template <typename T, typename Context>
void GroupNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& y,
const DenseTensor& mean,
const DenseTensor& var,
const DenseTensor& d_y,
float epsilon,
int groups,
const std::string& data_layout_str,
DenseTensor* d_x,
DenseTensor* d_scale,
DenseTensor* d_bias) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
const auto& x_dims = x.dims();
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = C / groups;
const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]);
dev_ctx.template Alloc<T>(d_x);
phi::funcs::SetConstant<GPUContext, T> set_zero;
DenseTensor ds, db;
ds.Resize({x_dims[0], C});
T* ds_data = dev_ctx.template Alloc<T>(&ds);
db.Resize({x_dims[0], C});
T* db_data = dev_ctx.template Alloc<T>(&db);
auto* y_data = y.data<T>();
auto* x_data = x.data<T>();
T* d_x_data = nullptr;
if (d_x) d_x_data = d_x->data<T>();
auto* dy_data = d_y.data<T>();
auto* var_data = var.data<T>();
auto* mean_data = mean.data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
dev_ctx.template Alloc<T>(d_scale);
d_scale_data = d_scale->data<T>();
}
T* d_bias_data = nullptr;
if (d_bias) {
dev_ctx.template Alloc<T>(d_bias);
d_bias_data = d_bias->data<T>();
}
const T* scale_data = nullptr;
if (scale_ptr) scale_data = scale_ptr->data<T>();
const T* bias_data = nullptr;
if (bias_ptr) bias_data = bias_ptr->data<T>();
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
const int block_dims = 256;
#else
int block_size = std::min(1024, imsize);
const int block_dims = 1024;
#endif
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
if (data_layout == DataLayout::kNCHW) {
const int max_num_threads = 1024;
int max_block_size = std::min(imsize, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 blocks(block_size_nchw);
ScalarGetDsDbCUDAKernel<T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
imsize, x_data, dy_data, ds_data, db_data);
if (d_scale || d_bias) {
const int block = 256;
GetScaleBiasGradientCUDAKernel<T>
<<<(C + block - 1) / block, block, 0, dev_ctx.stream()>>>(
x_dims[0],
C,
groups,
epsilon,
mean_data,
var_data,
ds_data,
db_data,
d_scale_data,
d_bias_data);
}
if (d_x_data != nullptr) {
// d_x = p1 * dy + p2 * x + p3, where p1, p2 and p3 are intermediate terms of
// the backward computation:
// p1 = scale * var_inv
// p2 = (db * scale * mean - ds * scale) * pow(var_inv, 3) * (1/n)
// p3 = -p2 * mean[ng] - db * scale * var_inv * (1/n)
DenseTensor p1, p2, p3;
p1.Resize({x_dims[0] * C});
T* p1_data = dev_ctx.template Alloc<T>(&p1);
p2.Resize({x_dims[0], groups});
T* p2_data = dev_ctx.template Alloc<T>(&p2);
p3.Resize({x_dims[0], groups});
T* p3_data = dev_ctx.template Alloc<T>(&p3);
GetBackwardParamsCUDAKernel<T, block_dims>
<<<dim3(x_dims[0], groups), block_dims, 0, dev_ctx.stream()>>>(
imsize,
groups,
group_size,
epsilon,
mean_data,
var_data,
scale_data,
ds_data,
db_data,
p1_data,
p2_data,
p3_data);
GetXGradientCUDAKernel<T>
<<<grid, threads, 0, dev_ctx.stream()>>>(imsize,
C,
group_size,
groups,
p1_data,
p2_data,
p3_data,
x_data,
dy_data,
d_x_data);
}
} else {
if (d_scale) {
set_zero(dev_ctx, d_scale, static_cast<T>(0));
}
if (d_bias) {
set_zero(dev_ctx, d_bias, static_cast<T>(0));
}
DenseTensor temp_var;
temp_var.Resize(var.dims());
dev_ctx.template Alloc<T>(&temp_var);
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
T* temp_var_data = temp_var.data<T>();
DenseTensor temp_mean;
temp_mean.Resize(var.dims());
dev_ctx.template Alloc<T>(&temp_mean);
set_zero(dev_ctx, &temp_mean, static_cast<T>(0));
T* temp_mean_data = temp_mean.data<T>();
int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags,
GroupNormBackwardGetMeanAndVar,
y_data,
scale_data,
bias_data,
dy_data,
x_dims[0],
C,
W,
imsize,
groups,
group_size,
epsilon,
temp_mean_data,
temp_var_data,
d_scale_data,
d_bias_data);
if (d_x_data != nullptr) {
UNROLL_ALL_CASES(flags,
GroupNormBackward,
y_data,
dy_data,
scale_data,
bias_data,
var_data,
temp_mean_data,
temp_var_data,
x_dims[0],
C,
W,
imsize,
groups,
group_size,
epsilon,
d_x_data);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
group_norm_grad, GPU, ALL_LAYOUT, phi::GroupNormGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/group_norm_utils.h"
#include "paddle/phi/kernels/group_norm_kernel.h"
namespace phi {
template <typename T>
__global__ void GroupNormForwardGetMeanAndVar(const T* x,
int N,
int C,
int W,
int imsize,
int groups,
int group_size,
T* mean,
T* var) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int H = imsize / W;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_mean = 0, x_var = 0;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val;
int hid = imid / W;
int wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid];
x_mean += val;
x_var += val * val;
}
x_mean /= number * imsize;
x_var /= number * imsize;
CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean);
CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
}
template <typename T, int flags>
__global__ void GroupNormForward(const T* x,
const T* mean,
const T* var,
const T* scale,
const T* bias,
int N,
int C,
int W,
int imsize,
int groups,
int group_size,
T epsilon,
T* y,
T* real_var,
const DataLayout data_layout) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int H = imsize / W;
int ccid = gid * group_size + cid;
if (ccid >= C) return;
auto ng = bid * groups + gid;
T x_mean = mean[ng];
T x_var = var[ng];
x_var = x_var - x_mean * x_mean;
T var_inv = rsqrt(x_var + epsilon);
if (cid == 0 && threadIdx.x == 0) {
real_var[ng] = x_var;
}
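// Normalize each element: y = (x - mean) * rsqrt(var + epsilon), then apply
// optional scale and bias; indexing covers both NCHW and NHWC layouts.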
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val;
int hid, wid;
int index = (bid * C + ccid) * imsize + imid;
if (data_layout == DataLayout::kNCHW) {
val = x[index];
} else {
hid = imid / W;
wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid];
}
val = (val - x_mean) * var_inv;
if (flags & kHasScale) {
val *= scale[ccid];
}
if (flags & kHasBias) {
val += bias[ccid];
}
if (data_layout == DataLayout::kNCHW) {
y[index] = val;
} else {
y[(bid * H + hid) * W * C + wid * C + ccid] = val;
}
}
}
template <typename T, typename Context>
void GroupNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int groups,
const std::string& data_layout_str,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
const auto x_dims = x.dims();
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = C / groups;
const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]);
dev_ctx.template Alloc<T>(y);
dev_ctx.template Alloc<T>(mean);
dev_ctx.template Alloc<T>(var);
phi::funcs::SetConstant<GPUContext, T> set_zero;
DenseTensor temp_var;
temp_var.Resize(var->dims());
dev_ctx.template Alloc<T>(&temp_var);
auto* x_data = x.data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
auto* temp_var_data = temp_var.data<T>();
const T* scale_data = nullptr;
if (scale_ptr) scale_data = scale_ptr->data<T>();
const T* bias_data = nullptr;
if (bias_ptr) bias_data = bias_ptr->data<T>();
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
#else
int block_size = std::min(1024, imsize);
#endif
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
if (data_layout == DataLayout::kNCHW) {
using AccT = typename kps::details::MPTypeTrait<T>::Type;
constexpr int vec_size = sizeof(float4) / sizeof(T);
int size = group_size * imsize;
const int max_num_threads = 1024;
int max_block_size = std::min(size / vec_size, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 grids(x_dims[0] * groups);
dim3 blocks(block_size_nchw);
if (size < vec_size * block_size_nchw) {
ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
x_data, mean_data, temp_var_data, size);
} else {
VectorizedGetMeanAndVarNCHW<T, AccT, vec_size>
<<<grids, blocks, 0, dev_ctx.stream()>>>(
x_data, mean_data, temp_var_data, size);
}
} else {
set_zero(dev_ctx, mean, static_cast<T>(0));
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
GroupNormForwardGetMeanAndVar<T>
<<<grid, threads, 0, dev_ctx.stream()>>>(x_data,
x_dims[0],
C,
W,
imsize,
groups,
group_size,
mean_data,
temp_var_data);
}
int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags,
GroupNormForward,
x_data,
mean_data,
temp_var_data,
scale_data,
bias_data,
x_dims[0],
C,
W,
imsize,
groups,
group_size,
epsilon,
y_data,
var_data,
data_layout);
}
} // namespace phi
PD_REGISTER_KERNEL(
group_norm, GPU, ALL_LAYOUT, phi::GroupNormKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace phi {
enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
#define ALIGN_BYTES 16
#define CHECK_CASE(i, flags, kernel_name, ...) \
if (i == flags) { \
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \
}
// 0 for no scale, no bias
// 1 for has scale, no bias
// 2 for no scale, has bias
// 3 for has scale, has bias
#define UNROLL_ALL_CASES(flags, kernel_name, ...) \
CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
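// Reduce `value` across the warp, then lane 0 atomically adds the warp sum to *sum.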
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
typedef cub::WarpReduce<T> WarpReduce;
typename WarpReduce::TempStorage temp_storage;
value = WarpReduce(temp_storage).Sum(value);
if (cub::LaneId() == 0) paddle::platform::CudaAtomicAdd(sum, value);
}
template <typename T, typename AccT, int VecSize, int Num>
__device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
int size,
const int offset,
AccT* out_mean,
AccT* out_var) {
const T* x = arrs[0];
const T* y;
if (Num == 2) {
y = arrs[1];
}
using VecT = kps::details::VectorType<T, VecSize>;
int tid = threadIdx.x;
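// `offset` is the element misalignment of x w.r.t. ALIGN_BYTES: consume the
// unaligned head with scalar loads so the vectorized loop below reads aligned
// VecT chunks.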
if (offset > 0) {
x -= offset;
if (Num == 2) {
y -= offset;
}
size += offset;
if (tid >= offset) {
if (Num == 1) {
*out_mean += x[tid];
*out_var += x[tid] * x[tid];
} else if (Num == 2) {
*out_mean += y[tid];
*out_var += y[tid] * x[tid];
}
}
size -= blockDim.x;
x += blockDim.x;
if (Num == 2) {
y += blockDim.x;
}
}
int remain = size % (VecSize * blockDim.x);
T ins_x[VecSize];
T ins_y[VecSize];
VecT* ins_vec_x = reinterpret_cast<VecT*>(&ins_x);
VecT* ins_vec_y = reinterpret_cast<VecT*>(&ins_y);
// vector part
for (; VecSize * tid < (size - remain); tid += blockDim.x) {
*ins_vec_x = reinterpret_cast<const VecT*>(x)[tid];
if (Num == 2) {
*ins_vec_y = reinterpret_cast<const VecT*>(y)[tid];
}
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
if (Num == 1) {
*out_mean += ins_x[i];
*out_var += ins_x[i] * ins_x[i];
} else if (Num == 2) {
*out_mean += ins_y[i];
*out_var += ins_y[i] * ins_x[i];
}
}
}
// scalar part
tid = size - remain + threadIdx.x;
for (; tid < size; tid += blockDim.x) {
if (Num == 1) {
*out_mean += x[tid];
*out_var += x[tid] * x[tid];
} else if (Num == 2) {
*out_mean += y[tid];
*out_var += y[tid] * x[tid];
}
}
}
template <typename T>
__device__ __forceinline__ void ReduceMeanAndVar(
T* mean, T* var, T x_mean, T x_var, int size) {
const int nc = blockIdx.x;
x_mean = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
x_mean, kps::AddFunctor<T>());
x_var = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
x_var, kps::AddFunctor<T>());
__syncthreads();
if (threadIdx.x == 0) {
mean[nc] = static_cast<T>(x_mean / size);
var[nc] = static_cast<T>(x_var / size);
}
}
template <typename T>
__global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
int i = blockIdx.x;
T x_mean = 0, x_var = 0;
for (int j = threadIdx.x; j < size; j += blockDim.x) {
T val;
val = x[i * size + j];
x_mean += val;
x_var += val * val;
}
ReduceMeanAndVar<T>(mean, var, x_mean, x_var, size);
}
template <typename T, typename AccT, int VecSize>
__global__ void VectorizedGetMeanAndVarNCHW(const T* x,
T* mean,
T* var,
int size) {
int i = blockIdx.x;
AccT x_mean = static_cast<AccT>(0);
AccT x_var = static_cast<AccT>(0);
x += i * size;
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
phi::Array<const T*, 1> ins;
ins[0] = x;
ThreadReduce<T, AccT, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
ReduceMeanAndVar<AccT>(mean, var, x_mean, x_var, size);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GroupNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& y,
const DenseTensor& mean,
const DenseTensor& variance,
const DenseTensor& d_y,
float epsilon,
int groups,
const std::string& data_layout,
DenseTensor* d_x,
DenseTensor* d_scale,
DenseTensor* d_bias);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
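// Group normalization: within each (sample, group) the input is normalized as
// y = (x - mean) / sqrt(variance + epsilon), optionally scaled and shifted per
// channel; mean/variance return the per-(batch, group) statistics.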
template <typename T, typename Context>
void GroupNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int groups,
const std::string& data_layout,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* variance);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
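// Map the fluid group_norm / group_norm_grad ops' inputs, attributes and
// outputs onto the PHI kernel signatures.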
KernelSignature GroupNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("group_norm",
{"X", "Scale", "Bias"},
{"epsilon", "groups", "data_layout"},
{"Y", "Mean", "Variance"});
}
KernelSignature GroupNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"group_norm_grad",
{"X", "Scale", "Bias", "Y", "Mean", "Variance", "Y@GRAD"},
{"epsilon", "groups", "data_layout"},
{"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(group_norm, phi::GroupNormOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(group_norm_grad,
phi::GroupNormGradOpArgumentMapping);
......@@ -1144,8 +1144,8 @@ class InstanceNorm(layers.Layer):
def forward(self, input):
if in_dygraph_mode():
out, _, _, = _C_ops.final_state_instance_norm(
input, self.scale, self.bias, self._epsilon)
out = _C_ops.final_state_instance_norm(input, self.scale, self.bias,
self._epsilon)
return out
if _in_legacy_dygraph():
out, _, _ = _C_ops.instance_norm(input, self.scale, self.bias,
......@@ -3031,8 +3031,14 @@ class GroupNorm(layers.Layer):
dtype=self._dtype, stop_gradient=True)
variance_out = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
if in_dygraph_mode():
out = _C_ops.final_state_group_norm(input, self.weight, self.bias,
self._epsilon, self._groups,
"NCHW")
if _non_static_mode():
return dygraph_utils._append_activation_in_dygraph(out, self._act)
elif _in_legacy_dygraph():
attrs = ('epsilon', self._epsilon, 'groups', self._groups)
out, _, _ = _C_ops.group_norm(input, self.weight, self.bias,
mean_out, variance_out, *attrs)
......
......@@ -20,7 +20,7 @@ from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
from paddle.fluid.framework import _test_eager_guard
from testsuite import create_op
......@@ -301,5 +301,30 @@ class TestGroupNormException(unittest.TestCase):
self.assertRaises(ValueError, attr_data_format)
class TestGroupNormEager(unittest.TestCase):
def test_dygraph_final_state_api(self):
self.dtype = np.float64
self.shape = (8, 32, 32)
input = np.random.random(self.shape).astype(self.dtype)
with fluid.dygraph.guard():
tensor_1 = fluid.dygraph.to_variable(input)
tensor_1.stop_gradient = False
groupNorm = fluid.dygraph.nn.GroupNorm(channels=32, groups=4)
ret1 = groupNorm(tensor_1)
ret1.backward()
with _test_eager_guard():
tensor_eager_1 = fluid.dygraph.to_variable(input)
tensor_eager_1.stop_gradient = False
groupNorm_eager = fluid.dygraph.nn.GroupNorm(channels=32,
groups=4)
ret2 = groupNorm_eager(tensor_eager_1)
ret2.backward()
self.assertEqual((
tensor_1.grad.numpy() == tensor_eager_1.grad.numpy()).all(),
True)
if __name__ == '__main__':
unittest.main()
......@@ -22,6 +22,7 @@ from op_test import OpTest, _set_use_system_allocator
from paddle.fluid.framework import grad_var_name
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
import paddle
......@@ -124,6 +125,10 @@ class TestDygraphGroupNormv2(unittest.TestCase):
y2 = compute_v2(x)
self.assertTrue(np.allclose(y1, y2, atol=1e-5))
def test_eager_api(self):
with _test_eager_guard():
self.test_dygraph()
class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
......@@ -154,6 +159,10 @@ class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
self.assertTrue(np.allclose(result1, expect_res1, atol=1e-5))
self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5))
def test_eager_api(self):
with _test_eager_guard():
self.test_numerical_accuracy()
class TestGroupNormDimException(unittest.TestCase):
......
......@@ -413,7 +413,7 @@ def instance_norm(x,
"""
if in_dygraph_mode():
out, _, _, = _C_ops.final_state_instance_norm(x, weight, bias, eps)
out = _C_ops.final_state_instance_norm(x, weight, bias, eps)
return out
if _in_legacy_dygraph():
out, _, _ = _C_ops.instance_norm(x, weight, bias, "epsilon", eps,
......
......@@ -933,6 +933,17 @@
kernel :
func : greater_than
- api : group_norm
args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int groups, str data_layout)
output : Tensor(y), Tensor(mean), Tensor(variance)
infer_meta :
func : GroupNormInferMeta
kernel :
func : group_norm
optional : scale, bias
intermediate : mean, variance
backward : group_norm_grad
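# `intermediate : mean, variance` keeps the saved statistics for the backward
# pass without exposing them through the generated Python API (hence the
# single-output eager calls in the Python changes above).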
- api : gumbel_softmax
args : (Tensor x, float temperature, bool hard, int axis)
output : Tensor
......@@ -1039,6 +1050,7 @@
func : instance_norm
data_type : x
optional : scale, bias
intermediate : saved_mean, saved_variance
backward : instance_norm_grad
# is_empty
......
......@@ -844,6 +844,19 @@
data_type : out_grad
optional: out, dst_count
- backward_api : group_norm_grad
forward : group_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int groups, str data_layout) -> Tensor(y), Tensor(mean), Tensor(variance)
args : (Tensor x, Tensor scale, Tensor bias, Tensor y, Tensor mean, Tensor variance, Tensor y_grad, float epsilon, int groups, str data_layout)
output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [y, scale, bias]
kernel :
func : group_norm_grad
data_type : y_grad
optional: scale, bias
inplace : (y_grad -> x_grad)
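# `inplace : (y_grad -> x_grad)` declares that x_grad may reuse y_grad's buffer.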
- backward_api : gumbel_softmax_grad
forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out)
args : (Tensor out, Tensor out_grad, int axis)
......
{
"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm", "instance_norm"],
"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm", "instance_norm", "group_norm"],
"phi_kernels":["equal_all"]
}