Unverified commit ece200b3, authored by: H hong, committed by: GitHub

Move norm to pten (#39324)

* add norm cpu

* update code;

* norm bug fix

* move norm op to pten; test=develop

* move norm op to pten; test=develop

* add norm util; test=develop

* fix norm npu bug; test=develop

* fix norm kernel bug; test=develop

* move kernel args to pten; test=develop

* move kernel args to pten sig; test=develop
Parent 3990e0bb
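
For reference, the op being migrated performs L2 normalization along a single axis. Per slice along that axis, the kernels in this diff compute (with \(\epsilon\) the `epsilon` attribute):

\[
\mathrm{norm} = \sqrt{\textstyle\sum x^2 + \epsilon},\qquad
y = \frac{x}{\mathrm{norm}},\qquad
dx = \frac{1}{\mathrm{norm}}\Bigl(dy - \frac{x\,\textstyle\sum x\,dy}{\mathrm{norm}^2}\Bigr)
\]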
...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/norm_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
...@@ -115,7 +115,3 @@ REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
                  ops::NormOpGradOpMaker<paddle::framework::OpDesc>,
                  ops::NormOpGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>,
                       ops::NormKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(norm_grad, ops::NormGradKernel<CPU, float>,
                       ops::NormGradKernel<CPU, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/norm_op.h"
#include "paddle/fluid/platform/bfloat16.h"
namespace paddle {
namespace operators {
__device__ __forceinline__ platform::float16 square_root(platform::float16 x) {
return static_cast<platform::float16>(sqrtf(static_cast<float>(x)));
}
__device__ __forceinline__ float square_root(float x) { return sqrtf(x); }
__device__ __forceinline__ double square_root(double x) { return sqrt(x); }
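// Kernel layout note: the input is viewed as [pre, axis_n, post]. Each of the
// pre * post rows along the normalization axis is processed by one block
// (grid-stride loop over blockIdx.x); the block's threads cooperatively reduce
// sum(x * x) with cub::BlockReduce, then write out the norm and the
// normalized values.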
template <typename T, int BlockDim>
__global__ void Normalize(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, const T eps, T* y, T* out_norm) {
using MT = typename details::MPTypeTrait<T>::Type;
typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
MT sum = 0.0;
__shared__ MT norm;
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const MT x_ij = static_cast<MT>(x[base + j * post]);
sum += x_ij * x_ij;
}
MT reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0) {
norm = square_root(reduce_result + static_cast<MT>(eps));
out_norm[i] = static_cast<T>(norm);
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const int index = base + j * post;
y[index] = static_cast<T>((static_cast<MT>(x[index]) / norm));
}
}
}
template <typename DeviceContext, typename T>
class NormCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* out_y = ctx.Output<framework::Tensor>("Out");
auto xdim = in_x->dims();
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
bool is_test = ctx.Attr<bool>("is_test");
framework::Tensor* out_norm;
framework::Tensor out_norm_tmp;
if (is_test) {
auto out_dim = in_x->dims();
out_dim[axis] = 1;
out_norm = &out_norm_tmp;
out_norm->Resize(out_dim);
} else {
out_norm = ctx.Output<framework::Tensor>("Norm");
}
const T* x = in_x->data<T>();
T* y = out_y->mutable_data<T>(ctx.GetPlace());
T* norm = out_norm->mutable_data<T>(ctx.GetPlace());
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
auto& dev_ctx = ctx.cuda_device_context();
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
Normalize<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
eps, y, norm);
}
};
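// Gradient kernel: for each of the pre * post rows, the block first reduces
// sum(x * dy) with cub::BlockReduce and then computes
//   dx = (dy - x * sum(x * dy) / norm^2) / norm,
// where norm is the value saved by the forward pass.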
template <typename T, int BlockDim>
__global__ void NormalizeGradient(const T* x, const T* x_norm, const T* y_grad,
const int pre, const int axis_n,
const int post, T* x_grad) {
using MT = typename details::MPTypeTrait<T>::Type;
typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage_sum;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
MT sum = 0.0;
__shared__ MT row_sum;
__shared__ MT row_sqrt_norm;
__shared__ MT row_norm;
auto base = (i / post) * post * axis_n + (i % post);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
sum += static_cast<MT>(x[index]) * static_cast<MT>(y_grad[index]);
}
MT reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
if (threadIdx.x == 0) {
row_sum = reduce_result;
row_sqrt_norm = static_cast<MT>(x_norm[i]);
row_norm = row_sqrt_norm * row_sqrt_norm;
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
const MT x_ij = static_cast<MT>(x[index]);
const MT dy_ij = static_cast<MT>(y_grad[index]);
x_grad[index] =
static_cast<T>((dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm);
}
}
}
template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* in_norm = ctx.Input<framework::Tensor>("Norm");
auto* in_dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
const T* x = in_x->data<T>();
const T* x_norm = in_norm->data<T>();
const T* dy = in_dy->data<T>();
auto xdim = in_x->dims();
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
auto& dev_ctx = ctx.cuda_device_context();
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
NormalizeGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
x, x_norm, dy, pre, n, post, dx);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(norm,
ops::NormCUDAKernel<CUDA, paddle::platform::float16>,
ops::NormCUDAKernel<CUDA, float>,
ops::NormCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(
norm_grad, ops::NormGradCUDAKernel<CUDA, paddle::platform::float16>,
ops::NormGradCUDAKernel<CUDA, float>,
ops::NormGradCUDAKernel<CUDA, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
int* post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
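// Example: for dims = [2, 3, 4, 5] and axis = 2, GetDims yields pre = 6,
// n = 4, post = 5; the tensor is then treated as [pre, n, post] and the norm
// is taken over the middle (size-n) dimension.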
template <typename DeviceContext, typename T>
class NormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* out_y = ctx.Output<framework::Tensor>("Out");
auto xdim = in_x->dims();
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
bool is_test = ctx.Attr<bool>("is_test");
framework::Tensor* out_norm;
framework::Tensor out_norm_tmp;
if (is_test) {
auto out_dim = in_x->dims();
out_dim[axis] = 1;
out_norm = &out_norm_tmp;
out_norm->Resize(out_dim);
} else {
out_norm = ctx.Output<framework::Tensor>("Norm");
}
out_y->mutable_data<T>(ctx.GetPlace());
out_norm->mutable_data<T>(ctx.GetPlace());
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 3> shape(pre, n, post);
Eigen::DSizes<int, 2> norm_shape(pre, post);
auto x_e = framework::EigenVector<T>::Flatten(*in_x);
auto y_e = framework::EigenVector<T>::Flatten(*out_y);
auto norm_e = framework::EigenVector<T>::Flatten(*out_norm);
auto x = x_e.reshape(shape);
auto y = y_e.reshape(shape);
auto norm = norm_e.reshape(norm_shape);
Eigen::DSizes<int, 1> rdim(1);
// y = x / sqrt((sum(x * x) + epsilon))
// norm = sqrt(sum(x * x) + epsilon)
auto x2 = x * x;
auto sum = x2.sum(rdim) + eps;
norm.device(*place) = sum.sqrt();
// y = x / norm
Eigen::DSizes<int, 3> rshape(pre, 1, post);
Eigen::DSizes<int, 3> bcast(1, n, 1);
y.device(*place) = x / norm.reshape(rshape).broadcast(bcast);
}
};
template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* in_norm = ctx.Input<framework::Tensor>("Norm");
auto* in_dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
out_dx->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims();
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
auto x_e = framework::EigenVector<T>::Flatten(*in_x);
auto dy_e = framework::EigenVector<T>::Flatten(*in_dy);
auto norm_e = framework::EigenVector<T>::Flatten(*in_norm);
auto dx_e = framework::EigenVector<T>::Flatten(*out_dx);
Eigen::DSizes<int, 3> shape(pre, n, post);
Eigen::DSizes<int, 3> rshape(pre, 1, post);
auto x = x_e.reshape(shape);
auto dy = dy_e.reshape(shape);
auto norm = norm_e.reshape(rshape);
auto dx = dx_e.reshape(shape);
framework::Tensor rsum;
rsum.mutable_data<T>({pre, post}, ctx.GetPlace());
auto sum = framework::EigenTensor<T, 2>::From(rsum);
Eigen::DSizes<int, 1> rdim(1);
Eigen::DSizes<int, 3> bcast(1, n, 1);
// dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
// = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
// = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x))
// 1. sum = sum(x*dy)
sum.device(*place) = (x * dy).sum(rdim);
// 2. dx = x * sum
dx.device(*place) = sum.reshape(rshape).broadcast(bcast) * x;
// 3. dx / (sum(x*x) + e)
// where, norm.pow(2) = sum(x*x) + e, which is calculated in forward.
dx.device(*place) = dx / norm.pow(2).broadcast(bcast);
// 4. [dy - dx] / sqrt(sum(x*x))
dx.device(*place) = (dy - dx) / norm.broadcast(bcast);
}
};
} // namespace operators
} // namespace paddle
...@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/norm_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
...@@ -24,6 +24,18 @@ limitations under the License. */
namespace pten {
constexpr char kGradVarSuffix[] = "@GRAD";
constexpr size_t kGradVarSuffixSize = 5U;
inline std::string GradVarName(const std::string& var_name) {
std::string result;
result.reserve(var_name.size() + kGradVarSuffixSize);
result += var_name;
result += kGradVarSuffix;
return result;
}
// tuple(input_names, attr_names, output_names)
using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
                                   paddle::SmallVector<std::string>,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/norm_grad_kernel.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/common_shape.h"
namespace pten {
template <typename T, typename Context>
void NormGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& norm,
int axis,
float epsilon,
bool is_test,
DenseTensor* x_grad) {
auto* in_x = &x;
auto* in_dy = &out_grad;
auto* in_norm = &norm;
auto* out_dx = x_grad;
ctx.template Alloc<T>(out_dx);
auto xdim = in_x->dims();
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post);
auto* place = ctx.eigen_device();
auto x_e = paddle::framework::EigenVector<T>::Flatten(*in_x);
auto dy_e = paddle::framework::EigenVector<T>::Flatten(*in_dy);
auto norm_e = paddle::framework::EigenVector<T>::Flatten(*in_norm);
auto dx_e = paddle::framework::EigenVector<T>::Flatten(*out_dx);
Eigen::DSizes<int, 3> shape(pre, n, post);
Eigen::DSizes<int, 3> rshape(pre, 1, post);
auto x_r = x_e.reshape(shape);
auto dy = dy_e.reshape(shape);
auto norm_r = norm_e.reshape(rshape);
auto dx = dx_e.reshape(shape);
DenseTensor rsum;
rsum.Resize({pre, post});
ctx.template Alloc<T>(&rsum);
auto sum = paddle::framework::EigenTensor<T, 2>::From(rsum);
Eigen::DSizes<int, 1> rdim(1);
Eigen::DSizes<int, 3> bcast(1, n, 1);
// dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
// = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
// = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x))
// 1. sum = sum(x*dy)
sum.device(*place) = (x_r * dy).sum(rdim);
// 2. dx = x * sum
dx.device(*place) = sum.reshape(rshape).broadcast(bcast) * x_r;
// 3. dx / (sum(x*x) + e)
// where, norm.pow(2) = sum(x*x) + e, which is calculated in forward.
dx.device(*place) = dx / norm_r.pow(2).broadcast(bcast);
// 4. [dy - dx] / sqrt(sum(x*x))
dx.device(*place) = (dy - dx) / norm_r.broadcast(bcast);
}
} // namespace pten
PT_REGISTER_KERNEL(
norm_grad, CPU, ALL_LAYOUT, pten::NormGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/norm_kernel.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/common_shape.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
namespace pten {
template <typename T, typename Context>
void NormKernel(const Context& ctx,
const DenseTensor& x,
int axis,
float epsilon,
bool is_test,
DenseTensor* out,
DenseTensor* norm) {
auto xdim = x.dims();
T eps = epsilon;
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post);
DenseTensor* out_norm;
DenseTensor out_norm_tmp;
if (is_test) {
auto out_dim = x.dims();
out_dim[axis] = 1;
out_norm = &out_norm_tmp;
out_norm->Resize(out_dim);
} else {
out_norm = norm;
}
ctx.template Alloc<T>(out);
ctx.template Alloc<T>(out_norm);
auto* place = ctx.eigen_device();
Eigen::DSizes<int, 3> shape(pre, n, post);
Eigen::DSizes<int, 2> norm_shape(pre, post);
auto x_e = paddle::framework::EigenVector<T>::Flatten(x);
auto y_e = paddle::framework::EigenVector<T>::Flatten(*out);
auto norm_e = paddle::framework::EigenVector<T>::Flatten(*out_norm);
auto x_r = x_e.reshape(shape);
auto y = y_e.reshape(shape);
auto norm_reshape = norm_e.reshape(norm_shape);
Eigen::DSizes<int, 1> rdim(1);
// y = x / sqrt((sum(x * x) + epsilon))
// norm = sqrt(sum(x * x) + epsilon)
auto x2 = x_r * x_r;
auto sum = x2.sum(rdim) + eps;
norm_reshape.device(*place) = sum.sqrt();
// y = x / norm
Eigen::DSizes<int, 3> rshape(pre, 1, post);
Eigen::DSizes<int, 3> bcast(1, n, 1);
y.device(*place) = x_r / norm_reshape.reshape(rshape).broadcast(bcast);
}
} // namespace pten
PT_REGISTER_KERNEL(norm, CPU, ALL_LAYOUT, pten::NormKernel, float, double) {}
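
As a quick sanity check of the layout handling above, the following standalone sketch reproduces the same forward math on a plain std::vector using the [pre, n, post] view. It is illustrative only; NormReference and the main driver are hypothetical names, not part of the Paddle/pten API:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Normalizes x, viewed as [pre, n, post], along the middle axis, writing the
// normalized values to y and the per-row norms (shape [pre, post]) to norm.
void NormReference(const std::vector<float>& x, int pre, int n, int post,
                   float eps, std::vector<float>* y, std::vector<float>* norm) {
  y->assign(x.size(), 0.0f);
  norm->assign(static_cast<size_t>(pre) * post, 0.0f);
  for (int i = 0; i < pre * post; ++i) {
    // Same index arithmetic as the kernels: row i starts at this base offset
    // and its elements are strided by `post`.
    const int base = (i / post) * post * n + (i % post);
    float sum = 0.0f;
    for (int j = 0; j < n; ++j) {
      const float v = x[base + j * post];
      sum += v * v;
    }
    const float nrm = std::sqrt(sum + eps);
    (*norm)[i] = nrm;
    for (int j = 0; j < n; ++j) {
      (*y)[base + j * post] = x[base + j * post] / nrm;
    }
  }
}

int main() {
  // x has shape [1, 3, 1]: pre = 1, n = 3, post = 1.
  const std::vector<float> x = {3.0f, 4.0f, 0.0f};
  std::vector<float> y, norm;
  NormReference(x, /*pre=*/1, /*n=*/3, /*post=*/1, /*eps=*/1e-10f, &y, &norm);
  std::printf("norm = %f, y = [%f, %f, %f]\n", norm[0], y[0], y[1], y[2]);
  // Expected: norm ~= 5, y ~= [0.6, 0.8, 0.0].
  return 0;
}
```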
...@@ -89,5 +89,18 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
  }
}
inline void GetPrePostNumel(
const framework::DDim &dim, int axis, int *pre, int *n, int *post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
} // namespace funcs
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/pten/kernels/norm_grad_kernel.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/pten/common/bfloat16.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/common_shape.h"
namespace pten {
template <typename T, int BlockDim>
__global__ void NormalizeGradient(const T* x,
const T* x_norm,
const T* y_grad,
const int pre,
const int axis_n,
const int post,
T* x_grad) {
using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage_sum;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
MT sum = 0.0;
__shared__ MT row_sum;
__shared__ MT row_sqrt_norm;
__shared__ MT row_norm;
auto base = (i / post) * post * axis_n + (i % post);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
sum += static_cast<MT>(x[index]) * static_cast<MT>(y_grad[index]);
}
MT reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
if (threadIdx.x == 0) {
row_sum = reduce_result;
row_sqrt_norm = static_cast<MT>(x_norm[i]);
row_norm = row_sqrt_norm * row_sqrt_norm;
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
const MT x_ij = static_cast<MT>(x[index]);
const MT dy_ij = static_cast<MT>(y_grad[index]);
x_grad[index] =
static_cast<T>((dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm);
}
}
}
template <typename T, typename Context>
void NormGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& norm,
int axis,
float epsilon,
bool is_test,
DenseTensor* x_grad) {
auto* in_x = &x;
auto* in_norm = &norm;
auto* in_dy = &out_grad;
auto* out_dx = x_grad;
ctx.template Alloc<T>(out_dx);
T* dx = out_dx->data<T>();
const T* x_data = in_x->data<T>();
const T* x_norm = in_norm->data<T>();
const T* dy = in_dy->data<T>();
auto xdim = in_x->dims();
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post);
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
int max_threads = ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
NormalizeGradient<T, block><<<grid, block, 0, ctx.stream()>>>(
x_data, x_norm, dy, pre, n, post, dx);
}
} // namespace pten
PT_REGISTER_KERNEL(norm_grad,
GPU,
ALL_LAYOUT,
pten::NormGradKernel,
float,
double,
paddle::platform::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/pten/kernels/norm_kernel.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/pten/common/float16.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/common_shape.h"
namespace pten {
__device__ __forceinline__ dtype::float16 square_root(dtype::float16 x) {
return static_cast<dtype::float16>(sqrtf(static_cast<float>(x)));
}
__device__ __forceinline__ float square_root(float x) { return sqrtf(x); }
__device__ __forceinline__ double square_root(double x) { return sqrt(x); }
template <typename T, int BlockDim>
__global__ void Normalize(const T* x,
const int pre,
const int axis_n, // dim in axis
const int post,
const T eps,
T* y,
T* out_norm) {
using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
MT sum = 0.0;
__shared__ MT norm;
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const MT x_ij = static_cast<MT>(x[base + j * post]);
sum += x_ij * x_ij;
}
MT reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0) {
norm = square_root(reduce_result + static_cast<MT>(eps));
out_norm[i] = static_cast<T>(norm);
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const int index = base + j * post;
y[index] = static_cast<T>((static_cast<MT>(x[index]) / norm));
}
}
}
template <typename T, typename Context>
void NormKernel(const Context& ctx,
const DenseTensor& x,
int axis,
float epsilon,
bool is_test,
DenseTensor* out,
DenseTensor* norm) {
auto* in_x = &x;
auto* out_y = out;
auto xdim = in_x->dims();
if (axis < 0) axis = xdim.size() + axis;
T eps = static_cast<T>(epsilon);
DenseTensor* out_norm;
DenseTensor out_norm_tmp;
if (is_test) {
auto out_dim = in_x->dims();
out_dim[axis] = 1;
out_norm = &out_norm_tmp;
out_norm->Resize(out_dim);
} else {
out_norm = norm;
}
const T* x_ptr = in_x->data<T>();
ctx.template Alloc<T>(out_y);
ctx.template Alloc<T>(out_norm);
T* y = out_y->data<T>();
T* norm_ptr = out_norm->data<T>();
int pre, n, post;
funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post);
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
int max_threads = ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
Normalize<T, block><<<grid, block, 0, ctx.stream()>>>(
x_ptr, pre, n, post, eps, y, norm_ptr);
}
} // namespace pten
PT_REGISTER_KERNEL(norm,
GPU,
ALL_LAYOUT,
pten::NormKernel,
float,
double,
paddle::platform::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void NormGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& out,
int axis,
float epsilon,
bool is_test,
DenseTensor* x_grad);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void NormKernel(const Context& ctx,
const DenseTensor& x,
int axis,
float epsilon,
bool is_test,
DenseTensor* out,
DenseTensor* norm);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature NormOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"norm", {"X"}, {"axis", "epsilon", "is_test"}, {"Out", "Norm"});
}
KernelSignature NormGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("norm_grad",
{GradVarName("Out"), "X", "Norm"},
{"axis", "epsilon", "is_test"},
{GradVarName("X")});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(norm, pten::NormOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(norm_grad, pten::NormGradOpArgumentMapping);
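
The `GradVarName` helper used in `NormGradOpArgumentMapping` is the copy added to pten's op_utils.h earlier in this diff; it simply appends the `@GRAD` suffix to a variable name. A minimal standalone sketch (the `demo` namespace and `main` driver are illustrative, not part of the pten build):

```cpp
#include <iostream>
#include <string>

namespace demo {
// Mirrors pten::GradVarName from the op_utils.h change above.
constexpr char kGradVarSuffix[] = "@GRAD";

std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}
}  // namespace demo

int main() {
  // The norm_grad signature therefore binds the fluid variables
  // "Out@GRAD" (input) and "X@GRAD" (output).
  std::cout << demo::GradVarName("Out") << "\n";  // Out@GRAD
  std::cout << demo::GradVarName("X") << "\n";    // X@GRAD
  return 0;
}
```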
...@@ -154,4 +154,6 @@ class API_NormTest(unittest.TestCase):
if __name__ == '__main__':
    import paddle
    paddle.enable_static()
    unittest.main()