Unverified commit e93e8a3f, authored by huangjiyi, committed by GitHub

update (#52878)

Parent: aac8da90
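The same mechanical change repeats in every hunk below, so here is a condensed before/after sketch of the pattern (a sketch only: the operator name foo and kernel class FooKernel are placeholders, not code from this commit). The kernel template swaps its parameters from <DeviceContext, T> to <T, DeviceContext>, and the per-device, per-dtype REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL entries are collapsed into a single PD_REGISTER_STRUCT_KERNEL(op, backend, layout, kernel, dtypes...) {} registration.

// Before: the device context is the first template parameter and every data
// type needs its own entry in the per-device registration macro.
//
//   template <typename DeviceContext, typename T>
//   class FooKernel : public framework::OpKernel<T> { ... };
//
//   REGISTER_OP_CPU_KERNEL(foo,
//                          ops::FooKernel<phi::CPUContext, float>,
//                          ops::FooKernel<phi::CPUContext, double>);
//
// After: the data type comes first, the backend (CPU/GPU) and layout are
// macro arguments, the data types are listed once, and the macro is closed
// with an empty brace body.
//
//   template <typename T, typename DeviceContext>
//   class FooKernel : public framework::OpKernel<T> { ... };
//
//   PD_REGISTER_STRUCT_KERNEL(
//       foo, CPU, ALL_LAYOUT, ops::FooKernel, float, double) {}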
@@ -53,7 +53,7 @@ class GetFloatStatusMaker : public framework::OpProtoAndCheckerMaker {
   }
 };

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class GetFloatStatusKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -75,4 +75,5 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

-REGISTER_OP_CPU_KERNEL(get_float_status, ops::GetFloatStatusKernel<CPU, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    get_float_status, CPU, ALL_LAYOUT, ops::GetFloatStatusKernel, float) {}
@@ -111,9 +111,12 @@ REGISTER_OPERATOR(global_gather,
                   ops::GlobalGatherOpGradMaker<paddle::framework::OpDesc>,
                   ops::GlobalGatherOpGradMaker<paddle::imperative::OpBase>)

-REGISTER_OP_CPU_KERNEL(global_gather,
-                       ops::GlobalGatherOpCPUKernel<float>,
-                       ops::GlobalGatherOpCPUKernel<double>,
-                       ops::GlobalGatherOpCPUKernel<int>,
-                       ops::GlobalGatherOpCPUKernel<int64_t>,
-                       ops::GlobalGatherOpCPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(global_gather,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::GlobalGatherOpCPUKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16) {}
@@ -261,7 +261,7 @@ struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
   }
 };

-template <typename T>
+template <typename T, typename DeviceContext>
 class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -283,9 +283,12 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_CUDA_KERNEL(global_gather,
-                        ops::GlobalGatherOpCUDAKernel<float>,
-                        ops::GlobalGatherOpCUDAKernel<double>,
-                        ops::GlobalGatherOpCUDAKernel<int>,
-                        ops::GlobalGatherOpCUDAKernel<int64_t>,
-                        ops::GlobalGatherOpCUDAKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(global_gather,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::GlobalGatherOpCUDAKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16) {}
@@ -25,7 +25,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename DeviceContext>
 class GlobalGatherOpCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
...
@@ -115,9 +115,12 @@ REGISTER_OPERATOR(global_scatter,
                   ops::GlobalScatterOpGradMaker<paddle::framework::OpDesc>,
                   ops::GlobalScatterOpGradMaker<paddle::imperative::OpBase>)

-REGISTER_OP_CPU_KERNEL(global_scatter,
-                       ops::GlobalScatterOpCPUKernel<float>,
-                       ops::GlobalScatterOpCPUKernel<double>,
-                       ops::GlobalScatterOpCPUKernel<int>,
-                       ops::GlobalScatterOpCPUKernel<int64_t>,
-                       ops::GlobalScatterOpCPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(global_scatter,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::GlobalScatterOpCPUKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16) {}
@@ -259,7 +259,7 @@ struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
   }
 };

-template <typename T>
+template <typename T, typename DeviceContext>
 class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -281,9 +281,12 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-REGISTER_OP_CUDA_KERNEL(global_scatter,
-                        ops::GlobalScatterOpCUDAKernel<float>,
-                        ops::GlobalScatterOpCUDAKernel<double>,
-                        ops::GlobalScatterOpCUDAKernel<int>,
-                        ops::GlobalScatterOpCUDAKernel<int64_t>,
-                        ops::GlobalScatterOpCUDAKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(global_scatter,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::GlobalScatterOpCUDAKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16) {}
@@ -25,7 +25,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename DeviceContext>
 class GlobalScatterOpCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
...
@@ -328,7 +328,7 @@ std::vector<phi::DenseTensor> SampleMaskForOneImage(
   return res;
 }

-template <typename T>
+template <typename T, typename DeviceContext>
 class GenerateMaskLabelsKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -533,5 +533,9 @@ REGISTER_OPERATOR(
     ops::GenerateMaskLabelsOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(generate_mask_labels,
-                       ops::GenerateMaskLabelsKernel<float>);
+
+PD_REGISTER_STRUCT_KERNEL(generate_mask_labels,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::GenerateMaskLabelsKernel,
+                          float) {}
@@ -510,7 +510,7 @@ std::vector<phi::DenseTensor> SampleRoisForOneImage(
   return res;
 }

-template <typename T>
+template <typename T, typename DeviceContext>
 class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -811,9 +811,12 @@ REGISTER_OPERATOR(
     ops::GenerateProposalLabelsOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(generate_proposal_labels,
-                       ops::GenerateProposalLabelsKernel<float>,
-                       ops::GenerateProposalLabelsKernel<double>);
+
+PD_REGISTER_STRUCT_KERNEL(generate_proposal_labels,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::GenerateProposalLabelsKernel,
+                          float,
+                          double) {}
 REGISTER_OP_VERSION(generate_proposal_labels)
     .AddCheckpoint(
...
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename DeviceContext>
 class CPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -99,7 +99,10 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     paddle::operators::BatchSizeLikeNoNeedBufferVarsInferer);

-REGISTER_OP_CPU_KERNEL(
-    gaussian_random_batch_size_like,
-    paddle::operators::CPUGaussianRandomBatchSizeLikeKernel<float>,
-    paddle::operators::CPUGaussianRandomBatchSizeLikeKernel<double>);
+namespace ops = paddle::operators;
+PD_REGISTER_STRUCT_KERNEL(gaussian_random_batch_size_like,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::CPUGaussianRandomBatchSizeLikeKernel,
+                          float,
+                          double) {}
@@ -47,7 +47,7 @@ struct GaussianGenerator {
   }
 };

-template <typename T>
+template <typename T, typename DeviceContext>
 class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -78,9 +78,12 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_CUDA_KERNEL(
-    gaussian_random_batch_size_like,
-    paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<
-        paddle::platform::float16>,
-    paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<float>,
-    paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<double>);
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(gaussian_random_batch_size_like,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::GPUGaussianRandomBatchSizeLikeKernel,
+                          float,
+                          double,
+                          plat::float16) {}
@@ -136,6 +136,10 @@ using CPU = phi::CPUContext;
 REGISTER_OPERATOR(graph_khop_sampler,
                   ops::GraphKhopSamplerOP,
                   ops::GraphKhopSamplerOpMaker);
-REGISTER_OP_CPU_KERNEL(graph_khop_sampler,
-                       ops::GraphKhopSamplerOpKernel<CPU, int32_t>,
-                       ops::GraphKhopSamplerOpKernel<CPU, int64_t>);
+
+PD_REGISTER_STRUCT_KERNEL(graph_khop_sampler,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::GraphKhopSamplerOpKernel,
+                          int32_t,
+                          int64_t) {}
@@ -412,7 +412,7 @@ void ReindexFunc(const framework::ExecutionContext& ctx,
                  thrust::raw_pointer_cast(values.data()));
 }

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class GraphKhopSamplerOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -668,6 +668,9 @@ class GraphKhopSamplerOpCUDAKernel : public framework::OpKernel<T> {
 using CUDA = phi::GPUContext;
 namespace ops = paddle::operators;

-REGISTER_OP_CUDA_KERNEL(graph_khop_sampler,
-                        ops::GraphKhopSamplerOpCUDAKernel<CUDA, int32_t>,
-                        ops::GraphKhopSamplerOpCUDAKernel<CUDA, int64_t>);
+PD_REGISTER_STRUCT_KERNEL(graph_khop_sampler,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::GraphKhopSamplerOpCUDAKernel,
+                          int32_t,
+                          int64_t) {}
@@ -191,7 +191,7 @@ void SampleNeighbors(const T* src,
   }
 }

-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class GraphKhopSamplerOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
...
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/group_norm_op.h"
-
 #include <memory>
 #include <string>
 #include <unordered_map>
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/group_norm_op.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using DataLayout = phi::DataLayout;
enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
#define ALIGN_BYTES 16
#define CHECK_CASE(i, flags, kernel_name, ...) \
if (i == flags) { \
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \
}
// 0 for no scale, no bias
// 1 for has scale, no bias
// 2 for no scale, has bias
// 3 for has scale, has bias
#define UNROLL_ALL_CASES(flags, kernel_name, ...) \
CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
typedef cub::WarpReduce<T> WarpReduce;
typename WarpReduce::TempStorage temp_storage;
value = WarpReduce(temp_storage).Sum(value);
if (cub::LaneId() == 0) phi::CudaAtomicAdd(sum, value);
}
template <typename T>
__global__ void GroupNormForwardGetMeanAndVar(const T* x,
int N,
int C,
int W,
int imsize,
int groups,
int group_size,
T* mean,
T* var) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int H = imsize / W;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_mean = 0, x_var = 0;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val;
int hid = imid / W;
int wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid];
x_mean += val;
x_var += val * val;
}
x_mean /= number * imsize;
x_var /= number * imsize;
CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean);
CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
}
template <typename T, typename AccT, int VecSize, int Num>
__device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
int size,
const int offset,
AccT* out_mean,
AccT* out_var) {
const T* x = arrs[0];
const T* y;
if (Num == 2) {
y = arrs[1];
}
using VecT = kps::details::VectorType<T, VecSize>;
int tid = threadIdx.x;
if (offset > 0) {
x -= offset;
if (Num == 2) {
y -= offset;
}
size += offset;
if (tid >= offset) {
if (Num == 1) {
*out_mean += x[tid];
*out_var += x[tid] * x[tid];
} else if (Num == 2) {
*out_mean += y[tid];
*out_var += y[tid] * x[tid];
}
}
size -= blockDim.x;
x += blockDim.x;
if (Num == 2) {
y += blockDim.x;
}
}
int remain = size % (VecSize * blockDim.x);
T ins_x[VecSize];
T ins_y[VecSize];
VecT* ins_vec_x = reinterpret_cast<VecT*>(&ins_x);
VecT* ins_vec_y = reinterpret_cast<VecT*>(&ins_y);
// vector part
for (; VecSize * tid < (size - remain); tid += blockDim.x) {
*ins_vec_x = reinterpret_cast<const VecT*>(x)[tid];
if (Num == 2) {
*ins_vec_y = reinterpret_cast<const VecT*>(y)[tid];
}
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
if (Num == 1) {
*out_mean += ins_x[i];
*out_var += ins_x[i] * ins_x[i];
} else if (Num == 2) {
*out_mean += ins_y[i];
*out_var += ins_y[i] * ins_x[i];
}
}
}
// scalar part
tid = size - remain + threadIdx.x;
for (; tid < size; tid += blockDim.x) {
if (Num == 1) {
*out_mean += x[tid];
*out_var += x[tid] * x[tid];
} else if (Num == 2) {
*out_mean += y[tid];
*out_var += y[tid] * x[tid];
}
}
}
template <typename T>
__device__ __forceinline__ void ReduceMeanAndVar(
T* mean, T* var, T x_mean, T x_var, int size) {
const int nc = blockIdx.x;
x_mean = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
x_mean, kps::AddFunctor<T>());
x_var = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
x_var, kps::AddFunctor<T>());
__syncthreads();
if (threadIdx.x == 0) {
mean[nc] = static_cast<T>(x_mean / size);
var[nc] = static_cast<T>(x_var / size);
}
}
template <typename T>
__global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
int i = blockIdx.x;
T x_mean = 0, x_var = 0;
for (int j = threadIdx.x; j < size; j += blockDim.x) {
T val;
val = x[i * size + j];
x_mean += val;
x_var += val * val;
}
ReduceMeanAndVar<T>(mean, var, x_mean, x_var, size);
}
template <typename T, typename AccT, int VecSize>
__global__ void VectorizedGetMeanAndVarNCHW(const T* x,
T* mean,
T* var,
int size) {
int i = blockIdx.x;
AccT x_mean = static_cast<AccT>(0);
AccT x_var = static_cast<AccT>(0);
x += i * size;
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
phi::Array<const T*, 1> ins;
ins[0] = x;
ThreadReduce<T, AccT, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
ReduceMeanAndVar<AccT>(mean, var, x_mean, x_var, size);
}
template <typename T, int flags>
__global__ void GroupNormForward(const T* x,
const T* mean,
const T* var,
const T* scale,
const T* bias,
int N,
int C,
int W,
int imsize,
int groups,
int group_size,
T epsilon,
T* y,
T* real_var,
const DataLayout data_layout) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int H = imsize / W;
int ccid = gid * group_size + cid;
if (ccid >= C) return;
auto ng = bid * groups + gid;
T x_mean = mean[ng];
T x_var = var[ng];
x_var = x_var - x_mean * x_mean;
T var_inv = rsqrt(x_var + epsilon);
if (cid == 0 && threadIdx.x == 0) {
real_var[ng] = x_var;
}
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val;
int hid, wid;
int index = (bid * C + ccid) * imsize + imid;
if (data_layout == DataLayout::kNCHW) {
val = x[index];
} else {
hid = imid / W;
wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid];
}
val = (val - x_mean) * var_inv;
if (flags & kHasScale) {
val *= scale[ccid];
}
if (flags & kHasBias) {
val += bias[ccid];
}
if (data_layout == DataLayout::kNCHW) {
y[index] = val;
} else {
y[(bid * H + hid) * W * C + wid * C + ccid] = val;
}
}
}
template <typename T>
class GroupNormKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const float epsilon = ctx.Attr<float>("epsilon");
auto* scale = ctx.Input<phi::DenseTensor>("Scale");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* y = ctx.Output<phi::DenseTensor>("Y");
auto* mean = ctx.Output<phi::DenseTensor>("Mean");
auto* var = ctx.Output<phi::DenseTensor>("Variance");
const auto groups = ctx.Attr<int>("groups");
const auto x_dims = x->dims();
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = C / groups;
const int W =
(data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]);
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<phi::GPUContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
phi::DenseTensor temp_var;
temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
auto* x_data = x->data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
auto* temp_var_data = temp_var.data<T>();
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
#else
int block_size = std::min(1024, imsize);
#endif
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
if (data_layout == DataLayout::kNCHW) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
constexpr int vec_size = sizeof(float4) / sizeof(T);
int size = group_size * imsize;
const int max_num_threads = 1024;
int max_block_size = std::min(size / vec_size, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 grids(x_dims[0] * groups);
dim3 blocks(block_size_nchw);
if (size < vec_size * block_size_nchw) {
ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
x_data, mean_data, temp_var_data, size);
} else {
VectorizedGetMeanAndVarNCHW<T, AccT, vec_size>
<<<grids, blocks, 0, dev_ctx.stream()>>>(
x_data, mean_data, temp_var_data, size);
}
} else {
set_zero(dev_ctx, mean, static_cast<T>(0));
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
GroupNormForwardGetMeanAndVar<T>
<<<grid, threads, 0, dev_ctx.stream()>>>(x_data,
x_dims[0],
C,
W,
imsize,
groups,
group_size,
mean_data,
temp_var_data);
}
int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags,
GroupNormForward,
x_data,
mean_data,
temp_var_data,
scale_data,
bias_data,
x_dims[0],
C,
W,
imsize,
groups,
group_size,
epsilon,
y_data,
var_data,
data_layout);
}
};
template <typename T, int flags>
__global__ void GroupNormBackwardGetMeanAndVar(const T* x,
const T* scale,
const T* bias,
const T* d_y,
int N,
int C,
int W,
int imsize,
int groups,
int group_size,
T epsilon,
T* d_mean,
T* d_var,
T* d_scale,
T* d_bias) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int H = imsize / W;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
T x_scale_inv = 0;
if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val, dval;
int hid = imid / W;
int wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias;
dval = d_y[(bid * H + hid) * W * C + wid * C + ccid];
d_var_data += val * dval;
d_mean_data += dval * x_scale;
val = val * x_scale_inv;
d_bias_data += dval;
d_scale_data += val * dval;
}
CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data);
CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data);
if (flags & kHasScale) {
#if CUDA_VERSION >= 11070
phi::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
#else
CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
#endif
}
if (flags & kHasBias) {
#if CUDA_VERSION >= 11070
phi::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
#else
CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
#endif
}
}
template <typename T, int flags>
__global__ void GroupNormBackward(const T* x,
const T* d_y,
const T* scale,
const T* bias,
const T* var,
const T* d_mean,
const T* d_var,
int N,
int C,
int W,
int imsize,
int groups,
int group_size,
T epsilon,
T* d_x) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int H = imsize / W;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_var = var[bid * groups + gid];
T d_x_mean = d_mean[bid * groups + gid];
T d_x_var = d_var[bid * groups + gid];
T x_var_inv = 1.0 / sqrt(x_var + epsilon);
T number_inv = 1.0 / (number * imsize);
T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
T x_scale_inv = 0;
if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
int hid = imid / W;
int wid = imid % W;
T tmp = x[(bid * H + hid) * W * C + wid * C + ccid];
T v_y = (tmp - x_bias) * x_scale_inv;
T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid];
d_x[(bid * H + hid) * W * C + wid * C + ccid] =
x_var_inv *
(dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
}
}
template <typename T>
__global__ void ScalarGetDsDbCUDAKernel(
int imsize, const T* x, const T* dy, T* ds, T* db) {
const int nc = blockIdx.x;
T ds_sum = 0;
T db_sum = 0;
for (int i = threadIdx.x; i < imsize; i += blockDim.x) {
const int index = nc * imsize + i;
ds_sum += dy[index] * x[index];
db_sum += dy[index];
}
ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1);
}
template <typename T>
__global__ void GetScaleBiasGradientCUDAKernel(int N,
int C,
int group,
T epsilon,
const T* mean,
const T* var,
const T* ds,
const T* db,
T* d_scale,
T* d_bias) {
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < C) {
const int G = group;
const int D = C / G;
T sum1 = 0;
T sum2 = 0;
for (int n = 0; n < N; ++n) {
const int nc = n * C + c;
const int ng = n * G + c / D;
sum1 += (d_scale == nullptr)
? T(0)
: ((ds[nc] - db[nc] * static_cast<T>(mean[ng])) *
static_cast<T>(rsqrt(var[ng] + epsilon)));
sum2 += (d_bias == nullptr) ? T(0) : db[nc];
}
if (d_scale != nullptr) {
d_scale[c] = sum1;
}
if (d_bias != nullptr) {
d_bias[c] = sum2;
}
}
}
template <typename T, int BlockDim>
__global__ void GetBackwardParamsCUDAKernel(int imsize,
int groups,
int group_size,
T epsilon,
const T* mean,
const T* var,
const T* scale,
const T* ds,
const T* db,
T* p1,
T* p2,
T* p3) {
const int n = blockIdx.x;
const int g = blockIdx.y;
const int ng = n * groups + g;
T sum1 = 0;
T sum2 = 0;
T var_inv = rsqrt(var[ng] + epsilon);
for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
const int64_t index = ng * group_size + i;
const int64_t c = g * group_size + i;
const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]);
sum1 += ds[index] * scale_v;
sum2 += db[index] * scale_v;
const T scale_c = scale == nullptr ? T(0) : static_cast<T>(scale[c]);
p1[index] = scale_c * var_inv;
}
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
sum1 = BlockReduce(ds_storage).Reduce(sum1, cub::Sum());
sum2 = BlockReduce(db_storage).Reduce(sum2, cub::Sum());
if (threadIdx.x == 0) {
const T s = T(1) / static_cast<T>(group_size * imsize);
const T x = (sum2 * static_cast<T>(mean[ng]) - sum1) *
static_cast<T>(var_inv) * static_cast<T>(var_inv) *
static_cast<T>(var_inv) * s;
p2[ng] = x;
p3[ng] = -x * static_cast<T>(mean[ng]) - sum2 * static_cast<T>(var_inv) * s;
}
}
template <typename T>
__global__ void GetXGradientCUDAKernel(int imsize,
int C,
int group_size,
int groups,
T* p1,
T* p2,
T* p3,
const T* x,
const T* dy,
T* dx) {
int cid = blockIdx.x;
int gid = blockIdx.y;
int bid = blockIdx.z;
int ccid = bid * C + gid * group_size + cid;
int ng = bid * groups + gid;
int nc = gid * group_size + cid;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
int index = (bid * C + nc) * imsize + imid;
dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng];
}
}
template <typename T>
class GroupNormGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const float epsilon = ctx.Attr<float>("epsilon");
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* y = ctx.Input<phi::DenseTensor>("Y");
auto* mean = ctx.Input<phi::DenseTensor>("Mean");
auto* var = ctx.Input<phi::DenseTensor>("Variance");
auto* scale = ctx.Input<phi::DenseTensor>("Scale");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto* d_y = ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
const auto groups = ctx.Attr<int>("groups");
// init output
auto* d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto* d_scale =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Scale"));
auto* d_bias = ctx.Output<phi::DenseTensor>(framework::GradVarName("Bias"));
const auto& x_dims = x->dims();
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = C / groups;
const int W =
(data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]);
d_x->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<phi::GPUContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
phi::DenseTensor ds, db;
ds.mutable_data<T>({x_dims[0], C}, ctx.GetPlace());
db.mutable_data<T>({x_dims[0], C}, ctx.GetPlace());
T* ds_data = ds.data<T>();
T* db_data = db.data<T>();
auto* y_data = y->data<T>();
auto* x_data = x->data<T>();
T* d_x_data = nullptr;
if (d_x) d_x_data = d_x->data<T>();
auto* dy_data = d_y->data<T>();
auto* var_data = var->data<T>();
auto* mean_data = mean->data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
d_scale_data = d_scale->data<T>();
}
T* d_bias_data = nullptr;
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
d_bias_data = d_bias->data<T>();
}
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
const int block_dims = 256;
#else
int block_size = std::min(1024, imsize);
const int block_dims = 1024;
#endif
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
if (data_layout == DataLayout::kNCHW) {
const int max_num_threads = 1024;
int max_block_size = std::min(imsize, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 blocks(block_size_nchw);
ScalarGetDsDbCUDAKernel<T>
<<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
imsize, x_data, dy_data, ds_data, db_data);
if (d_scale || d_bias) {
const int block = 256;
GetScaleBiasGradientCUDAKernel<T>
<<<(C + block - 1) / block, block, 0, dev_ctx.stream()>>>(
x_dims[0],
C,
groups,
epsilon,
mean_data,
var_data,
ds_data,
db_data,
d_scale_data,
d_bias_data);
}
if (d_x_data != nullptr) {
// p1 * dy + p2 * x + p3,
// p1, p2, p3 represent the reverse calculation of temporary variables
// p1 = scale * var_inv
// p2 = (db * scale * mean - ds * scale) * pow(var_inv, 3) * (1/n)
// p3 = -p2 * mean[ng] - db * scale * var_inv * (1/n);
phi::DenseTensor p1, p2, p3;
p1.mutable_data<T>({x_dims[0] * C}, ctx.GetPlace());
p2.mutable_data<T>({x_dims[0], groups}, ctx.GetPlace());
p3.mutable_data<T>({x_dims[0], groups}, ctx.GetPlace());
T* p1_data = p1.data<T>();
T* p2_data = p2.data<T>();
T* p3_data = p3.data<T>();
GetBackwardParamsCUDAKernel<T, block_dims>
<<<dim3(x_dims[0], groups), block_dims, 0, dev_ctx.stream()>>>(
imsize,
groups,
group_size,
epsilon,
mean_data,
var_data,
scale_data,
ds_data,
db_data,
p1_data,
p2_data,
p3_data);
GetXGradientCUDAKernel<T>
<<<grid, threads, 0, dev_ctx.stream()>>>(imsize,
C,
group_size,
groups,
p1_data,
p2_data,
p3_data,
x_data,
dy_data,
d_x_data);
}
} else {
if (d_scale) {
set_zero(dev_ctx, d_scale, static_cast<T>(0));
}
if (d_bias) {
set_zero(dev_ctx, d_bias, static_cast<T>(0));
}
phi::DenseTensor temp_var;
temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
T* temp_var_data = temp_var.data<T>();
phi::DenseTensor temp_mean;
temp_mean.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, &temp_mean, static_cast<T>(0));
T* temp_mean_data = temp_mean.data<T>();
int flags = (scale_data != nullptr) * kHasScale +
(bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags,
GroupNormBackwardGetMeanAndVar,
y_data,
scale_data,
bias_data,
dy_data,
x_dims[0],
C,
W,
imsize,
groups,
group_size,
epsilon,
temp_mean_data,
temp_var_data,
d_scale_data,
d_bias_data);
if (d_x_data != nullptr) {
UNROLL_ALL_CASES(flags,
GroupNormBackward,
y_data,
dy_data,
scale_data,
bias_data,
var_data,
temp_mean_data,
temp_var_data,
x_dims[0],
C,
W,
imsize,
groups,
group_size,
epsilon,
d_x_data);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(group_norm,
ops::GroupNormKernel<phi::GPUContext, float>,
ops::GroupNormKernel<phi::GPUContext, double>);
REGISTER_OP_CUDA_KERNEL(group_norm_grad,
ops::GroupNormGradKernel<phi::GPUContext, float>,
ops::GroupNormGradKernel<phi::GPUContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <array>
#include <numeric>
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using DataLayout = phi::DataLayout;
template <typename DeviceContext, typename T>
class GroupNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const float epsilon = ctx.Attr<float>("epsilon");
auto* scale = ctx.Input<phi::DenseTensor>("Scale");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* y = ctx.Output<phi::DenseTensor>("Y");
auto* mean = ctx.Output<phi::DenseTensor>("Mean");
auto* var = ctx.Output<phi::DenseTensor>("Variance");
const auto groups = ctx.Attr<int>("groups");
const auto x_dims = x->dims();
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = C / groups;
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto* x_data = x->data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
auto* iter_x_data = x_data;
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
for (int gid = 0; gid < groups; gid++) {
const int64_t M = 8;
std::array<T, M> x_mean_arr;
std::array<T, M> x_var_arr;
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
T x_mean = 0, x_var = 0;
int number =
std::min(group_size, static_cast<int>(C - gid * group_size));
auto* tmp_x = iter_x_data;
auto* x_src_data = iter_x_data;
auto* tmp_y = iter_y_data;
auto* y_src_data = iter_y_data;
if (data_layout == DataLayout::kNCHW) {
for (int cid = 0; cid < number; cid++) {
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M) {
// TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance
x_mean_arr[0] += iter_x_data[0];
x_var_arr[0] += iter_x_data[0] * iter_x_data[0];
x_mean_arr[1] += iter_x_data[1];
x_var_arr[1] += iter_x_data[1] * iter_x_data[1];
x_mean_arr[2] += iter_x_data[2];
x_var_arr[2] += iter_x_data[2] * iter_x_data[2];
x_mean_arr[3] += iter_x_data[3];
x_var_arr[3] += iter_x_data[3] * iter_x_data[3];
x_mean_arr[4] += iter_x_data[4];
x_var_arr[4] += iter_x_data[4] * iter_x_data[4];
x_mean_arr[5] += iter_x_data[5];
x_var_arr[5] += iter_x_data[5] * iter_x_data[5];
x_mean_arr[6] += iter_x_data[6];
x_var_arr[6] += iter_x_data[6] * iter_x_data[6];
x_mean_arr[7] += iter_x_data[7];
x_var_arr[7] += iter_x_data[7] * iter_x_data[7];
}
x_mean =
std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
x_var =
std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
for (; imid < imsize; imid++, iter_x_data++) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
} else {
for (int cid = 0; cid < number; cid++) {
iter_x_data = tmp_x + cid;
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M * C) {
// TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance
x_mean_arr[0] += iter_x_data[0 * C];
x_var_arr[0] += iter_x_data[0 * C] * iter_x_data[0 * C];
x_mean_arr[1] += iter_x_data[1 * C];
x_var_arr[1] += iter_x_data[1 * C] * iter_x_data[1 * C];
x_mean_arr[2] += iter_x_data[2 * C];
x_var_arr[2] += iter_x_data[2 * C] * iter_x_data[2 * C];
x_mean_arr[3] += iter_x_data[3 * C];
x_var_arr[3] += iter_x_data[3 * C] * iter_x_data[3 * C];
x_mean_arr[4] += iter_x_data[4 * C];
x_var_arr[4] += iter_x_data[4 * C] * iter_x_data[4 * C];
x_mean_arr[5] += iter_x_data[5 * C];
x_var_arr[5] += iter_x_data[5 * C] * iter_x_data[5 * C];
x_mean_arr[6] += iter_x_data[6 * C];
x_var_arr[6] += iter_x_data[6 * C] * iter_x_data[6 * C];
x_mean_arr[7] += iter_x_data[7 * C];
x_var_arr[7] += iter_x_data[7 * C] * iter_x_data[7 * C];
}
x_mean =
std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
x_var =
std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
for (; imid < imsize; imid++, iter_x_data += C) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
iter_x_data = tmp_x + group_size;
}
x_mean /= number * imsize;
x_var /= number * imsize;
x_var = std::max(x_var - x_mean * x_mean, T(0));
T var_inv = T(1) / std::sqrt(x_var + epsilon);
mean_data[bid * groups + gid] = x_mean;
var_data[bid * groups + gid] = x_var;
if (data_layout == DataLayout::kNCHW) {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize; imid++, tmp_x++, iter_y_data++) {
T val = (tmp_x[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
} else {
for (int cid = 0; cid < number; cid++) {
tmp_x = x_src_data + cid;
iter_y_data = y_src_data + cid;
for (int imid = 0; imid < imsize;
imid++, tmp_x += C, iter_y_data += C) {
T val = (tmp_x[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
iter_y_data = tmp_y + group_size;
}
}
if (data_layout == DataLayout::kNHWC) {
iter_x_data = x_data + (bid + 1) * C * imsize;
iter_y_data = y_data + (bid + 1) * C * imsize;
}
}
}
};
template <typename DeviceContext, typename T>
class GroupNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const float epsilon = ctx.Attr<float>("epsilon");
auto* x = ctx.Input<phi::DenseTensor>("Y");
auto* var = ctx.Input<phi::DenseTensor>("Variance");
auto* scale = ctx.Input<phi::DenseTensor>("Scale");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto* d_y = ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
const auto groups = ctx.Attr<int>("groups");
// init output
auto* d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto* d_scale =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Scale"));
auto* d_bias = ctx.Output<phi::DenseTensor>(framework::GradVarName("Bias"));
const auto& x_dims = x->dims();
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = C / groups;
d_x->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* x_data = x->data<T>();
auto* d_x_data = d_x->data<T>();
auto* y_data = d_y->data<T>();
auto* var_data = var->data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_scale, static_cast<T>(0));
d_scale_data = d_scale->data<T>();
}
T* d_bias_data = nullptr;
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_bias, static_cast<T>(0));
d_bias_data = d_bias->data<T>();
}
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
auto* iter_x_data = x_data;
auto* iter_d_x_data = d_x_data;
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
for (int gid = 0; gid < groups; gid++) {
T x_var = var_data[bid * groups + gid];
T var_inv = 1.0 / sqrt(x_var + epsilon);
int number =
std::min(group_size, static_cast<int>(C - gid * group_size));
T number_inv = 1.0 / (number * imsize);
auto* tmp_x = iter_x_data;
auto* tmp_y = iter_y_data;
auto* tmp_d_x = iter_d_x_data;
auto* x_src_data = iter_x_data;
auto* y_src_data = iter_y_data;
auto* iter_x_data_backup = iter_x_data;
auto* iter_y_data_backup = iter_y_data;
auto* iter_d_x_data_backup = iter_d_x_data;
T dp_scale = 0, dp_bias = 0;
if (data_layout == DataLayout::kNCHW) {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize;
imid++, iter_x_data++, iter_y_data++) {
T val = iter_x_data[0];
if (bias_data) val -= bias_data[gid * group_size + cid];
T dval = iter_y_data[0];
dp_scale += val * dval;
if (scale_data)
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data && scale_data[gid * group_size + cid] != 0)
val /= scale_data[gid * group_size + cid];
if (d_bias_data) d_bias_data[gid * group_size + cid] += dval;
if (d_scale_data)
d_scale_data[gid * group_size + cid] += val * dval;
}
}
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize;
imid++, iter_d_x_data++, tmp_x++, tmp_y++) {
T v_y = tmp_x[0];
T dly = tmp_y[0];
T dss = dp_scale;
T dbs = dp_bias;
T v_scale = 1., v_bias = 0.;
if (scale_data) v_scale = scale_data[gid * group_size + cid];
if (bias_data) v_bias = bias_data[gid * group_size + cid];
v_y -= v_bias;
if (v_scale != 0) v_y /= v_scale;
iter_d_x_data[0] =
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
var_inv;
}
}
} else {
for (int cid = 0; cid < number; cid++) {
iter_x_data = x_src_data + cid;
iter_y_data = y_src_data + cid;
for (int imid = 0; imid < imsize;
imid++, iter_x_data += C, iter_y_data += C) {
T val = iter_x_data[0];
if (bias_data) val -= bias_data[gid * group_size + cid];
T dval = iter_y_data[0];
dp_scale += val * dval;
if (scale_data)
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data && scale_data[gid * group_size + cid] != 0)
val /= scale_data[gid * group_size + cid];
if (d_bias_data) d_bias_data[gid * group_size + cid] += dval;
if (d_scale_data)
d_scale_data[gid * group_size + cid] += val * dval;
}
}
for (int cid = 0; cid < number; cid++) {
tmp_x = x_src_data + cid;
tmp_y = y_src_data + cid;
iter_d_x_data = tmp_d_x + cid;
for (int imid = 0; imid < imsize;
imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) {
T v_y = tmp_x[0];
T dly = tmp_y[0];
T dss = dp_scale;
T dbs = dp_bias;
T v_scale = 1.0, v_bias = 0.;
if (scale_data) v_scale = scale_data[gid * group_size + cid];
if (bias_data) v_bias = bias_data[gid * group_size + cid];
v_y -= v_bias;
if (v_scale != 0) v_y /= v_scale;
iter_d_x_data[0] =
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
var_inv;
}
}
iter_x_data = iter_x_data_backup + group_size;
iter_y_data = iter_y_data_backup + group_size;
iter_d_x_data = iter_d_x_data_backup + group_size;
}
}
if (data_layout == DataLayout::kNHWC) {
iter_x_data = x_data + (bid + 1) * C * imsize;
iter_d_x_data = d_x_data + (bid + 1) * C * imsize;
iter_y_data = y_data + (bid + 1) * C * imsize;
}
}
}
};
} // namespace operators
} // namespace paddle
@@ -91,10 +91,13 @@ REGISTER_OPERATOR(l1_norm,
                  ops::L1NormGradMaker<paddle::framework::OpDesc>,
                  ops::L1NormGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(l1_norm_grad, ops::L1NormGradOp);

-REGISTER_OP_CPU_KERNEL(l1_norm, ops::L1NormKernel<phi::CPUContext, float>);
-REGISTER_OP_CPU_KERNEL(l1_norm_grad,
-                       ops::L1NormGradKernel<phi::CPUContext, float>);
-REGISTER_OP_CUDA_KERNEL(l1_norm, ops::L1NormKernel<phi::GPUContext, float>);
-REGISTER_OP_CUDA_KERNEL(l1_norm_grad,
-                        ops::L1NormGradKernel<phi::GPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(l1_norm, CPU, ALL_LAYOUT, ops::L1NormKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    l1_norm_grad, CPU, ALL_LAYOUT, ops::L1NormGradKernel, float) {}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_STRUCT_KERNEL(l1_norm, GPU, ALL_LAYOUT, ops::L1NormKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    l1_norm_grad, GPU, ALL_LAYOUT, ops::L1NormGradKernel, float) {}
+#endif
@@ -21,7 +21,7 @@ namespace paddle {
 namespace operators {

 // Out = sum(abs(X))
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class L1NormKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
@@ -39,7 +39,7 @@ class L1NormKernel : public framework::OpKernel<T> {
 };

 // dX = dout * sign(X)
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class L1NormGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
...