未验证 提交 10d9ab4b 编写于 作者: N Noel 提交者: GitHub

[pnorm] Optimize p_norm op for special cases (#37685)

上级 3a339cc0
......@@ -21,7 +21,10 @@ limitations under the License. */
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/p_norm_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -56,87 +59,94 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
return pow(base, exponent);
}
template <typename T, int BlockDim>
__global__ void Pnorm(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, float porder, T* out_norm) {
using MT = typename details::MPTypeTrait<T>::Type;
typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
auto porder_t = static_cast<MT>(porder);
auto porder_inv = static_cast<MT>(1.0 / porder);
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
MT sum = static_cast<MT>(0.0);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const MT x_ij = static_cast<MT>(x[base + j * post]);
sum += inline_pow(inline_abs(x_ij), porder_t);
}
MT reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0)
out_norm[i] = static_cast<T>(inline_pow(reduce_result, porder_inv));
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(x);
}
}
};
template <typename T, int BlockDim>
__global__ void ZeorNorm(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, T* out_norm) {
using MT = typename details::MPTypeTrait<T>::Type;
typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
MT sum = static_cast<MT>(0.0);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const MT x_ij = static_cast<MT>(x[base + j * post]);
sum += static_cast<MT>(static_cast<double>(x_ij) != 0);
}
MT reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0) out_norm[i] = static_cast<T>(reduce_result);
struct NonzeroFunctor {
HOSTDEVICE explicit inline NonzeroFunctor() {}
HOSTDEVICE explicit inline NonzeroFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(static_cast<double>(x) != 0);
}
}
};
template <typename T, int BlockDim>
__global__ void InfNorm(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, T* out_norm) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
T cur_max = inline_abs(x[base]);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
T x_ij_abs = inline_abs(x[base + j * post]);
if (cur_max < x_ij_abs) cur_max = x_ij_abs;
}
T reduce_result = BlockReduce(temp_storage).Reduce(cur_max, cub::Max());
if (threadIdx.x == 0) out_norm[i] = reduce_result;
struct AbsFunctor {
HOSTDEVICE explicit inline AbsFunctor() {}
HOSTDEVICE explicit inline AbsFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(inline_abs(x));
}
}
};
template <typename T, int BlockDim>
__global__ void NegInfNorm(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, T* out_norm) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
T cur_min = inline_abs(x[base]);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
T x_ij_abs = inline_abs(x[base + j * post]);
if (cur_min > x_ij_abs) cur_min = x_ij_abs;
}
T reduce_result = BlockReduce(temp_storage).Reduce(cur_min, cub::Min());
if (threadIdx.x == 0) out_norm[i] = reduce_result;
template <typename Tx, typename Ty = Tx>
struct UnsignedPowFunctor {
HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
this->porder = porder;
}
}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(inline_pow(inline_abs(x), static_cast<Tx>(porder)));
}
float porder;
};
template <typename Tx, typename Ty = Tx>
struct PowFunctor {
HOSTDEVICE explicit inline PowFunctor(float porder) { this->porder = porder; }
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(inline_pow(x, static_cast<Tx>(porder)));
}
float porder;
};
template <typename Tx, typename Ty = Tx>
struct AbsAndMin {
using Transformer = AbsFunctor;
using MT = typename details::MPTypeTrait<Ty>::Type;
inline Ty initial() {
return static_cast<Ty>(std::numeric_limits<MT>::infinity());
}
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return (a < b) ? a : b;
}
};
template <typename Tx, typename Ty = Tx>
struct AbsAndMax {
using Transformer = AbsFunctor;
using MT = typename details::MPTypeTrait<Ty>::Type;
inline Ty initial() {
return static_cast<Ty>(-std::numeric_limits<MT>::infinity());
}
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return (a > b) ? a : b;
}
};
template <typename Tx, typename Ty = Tx>
struct NonzeroAndSum {
using Transformer = NonzeroFunctor;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};
template <typename Tx, typename Ty = Tx>
struct IdentityAndSum {
using Transformer = IdentityFunctor;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};
template <typename DeviceContext, typename T>
class PnormCUDAKernel : public framework::OpKernel<T> {
......@@ -146,101 +156,83 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
auto* out_norm = ctx.Output<framework::Tensor>("Out");
const T* x = in_x->data<T>();
T* norm = out_norm->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims();
auto ndim = out_norm->dims();
float porder = ctx.Attr<float>("porder");
int axis = ctx.Attr<int>("axis");
bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post, asvector);
auto& dev_ctx = ctx.cuda_device_context();
std::vector<int> reduce_axis = {axis};
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
auto stream = ctx.cuda_device_context().stream();
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
using MT = typename details::MPTypeTrait<T>::Type;
if (porder == 0) {
ZeorNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
norm);
TensorReduceFunctorImpl<T, T, NonzeroAndSum>(*in_x, out_norm, reduce_axis,
stream);
} else if (porder == INFINITY) {
InfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
norm);
TensorReduceFunctorImpl<T, T, AbsAndMax>(*in_x, out_norm, reduce_axis,
stream);
} else if (porder == -INFINITY) {
NegInfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n,
post, norm);
TensorReduceFunctorImpl<T, T, AbsAndMin>(*in_x, out_norm, reduce_axis,
stream);
} else {
Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
porder, norm);
framework::Tensor tmp_x;
tmp_x.mutable_data<T>(xdim, ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {in_x};
std::vector<framework::Tensor*> outs = {&tmp_x};
auto func = UnsignedPowFunctor<MT, T>(porder);
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, MT, T,
UnsignedPowFunctor<MT, T>>(
cuda_ctx, ins, &outs, func);
framework::Tensor tmp_y;
tmp_y.mutable_data<T>(ndim, ctx.GetPlace());
TensorReduceFunctorImpl<T, T, IdentityAndSum>(tmp_x, &tmp_y, reduce_axis,
stream);
const framework::Tensor* tmp_norm = &tmp_y;
ins = {tmp_norm};
outs = {out_norm};
auto func_inverse = UnsignedPowFunctor<MT, T>(1. / porder);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, MT, T,
UnsignedPowFunctor<MT, T>>(
cuda_ctx, ins, &outs, func_inverse);
}
}
};
template <typename T, int BlockDim>
__global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
const float porder, const int pre,
const int axis_n, const int post, const T eps,
T* x_grad) {
using MT = typename details::MPTypeTrait<T>::Type;
// dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
int num = pre * post;
auto porder_grad = static_cast<MT>(porder - 1.0f);
for (int i = blockIdx.x; i < num; i += gridDim.x) {
__shared__ MT pnorm_i;
__shared__ MT yout_i;
auto base = (i / post) * post * axis_n + (i % post);
if (threadIdx.x == 0) {
pnorm_i = static_cast<MT>(x_norm[i]);
yout_i = static_cast<MT>(y_grad[i]);
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
const MT x_ij = static_cast<MT>(inline_abs(x[index]));
x_grad[index] = static_cast<T>(
inline_pow(x_ij, porder_grad) /
(inline_pow(pnorm_i, porder_grad) + static_cast<MT>(eps)) * yout_i *
static_cast<MT>(inline_sign(x[index])));
}
template <typename T>
struct AbsMaxAndMinGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
auto equals = ((*x).abs() == y->broadcast(dim));
auto ones = dx->constant(static_cast<T>(1.));
auto negs = dx->constant(static_cast<T>(-1.));
auto zeros = dx->constant(static_cast<T>(0.));
auto positives = (*x) > zeros;
dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) *
positives.select(ones, negs);
}
}
};
template <typename T, int BlockDim>
__global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
const int pre, const int axis_n, const int post,
T* x_grad) {
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
__shared__ T pnorm_i;
__shared__ T yout_i;
auto base = (i / post) * post * axis_n + (i % post);
if (threadIdx.x == 0) {
pnorm_i = x_norm[i];
yout_i = y_grad[i];
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
const T x_ij = inline_abs(x[index]);
if (x_ij == pnorm_i) {
x_grad[index] = static_cast<T>(inline_sign(x[index])) * yout_i;
} else {
x_grad[index] = static_cast<T>(0);
}
}
template <typename T>
struct PNormPostGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
auto ones = dx->constant(static_cast<T>(1.));
auto negs = dx->constant(static_cast<T>(-1.));
auto zeros = dx->constant(static_cast<T>(0.));
auto positives = (*x) > zeros;
dx->device(place) = (*dx) * dy->broadcast(dim) * y->broadcast(dim) *
positives.select(ones, negs);
}
}
};
template <typename DeviceContext, typename T, typename AttrType = T>
class PnormGradCUDAKernel : public framework::OpKernel<T> {
......@@ -252,40 +244,40 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
const T* x = in_x->data<T>();
const T* x_norm = in_norm->data<T>();
const T* norm_dy = in_norm_dy->data<T>();
auto xdim = in_x->dims();
float porder = ctx.Attr<float>("porder");
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
int axis = ctx.Attr<int>("axis");
bool asvector = ctx.Attr<bool>("asvector");
bool reduce_all = ((axis < 0) || (in_norm->numel() == 1));
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post, asvector);
const std::vector<int> dims = {axis};
auto& dev_ctx = ctx.cuda_device_context();
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
auto& cuda_ctx = ctx.template device_context<DeviceContext>();
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
if (porder == 0) {
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
set_zero(dev_ctx, out_dx, static_cast<T>(0));
set_zero(cuda_ctx, out_dx, static_cast<T>(0));
} else if (porder == INFINITY || porder == -INFINITY) {
InfNormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
x, x_norm, norm_dy, pre, n, post, dx);
LaunchReduceGradKernel<DeviceContext, T, AbsMaxAndMinGradFunctor<T>>(
ctx, in_x, in_norm, in_norm_dy, out_dx, dims, reduce_all);
} else {
PnormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
framework::Tensor tmp_norm;
tmp_norm.mutable_data<T>(in_norm->dims(), ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {in_norm};
std::vector<framework::Tensor*> outs = {&tmp_norm};
auto pow_functor = PowFunctor<T>(1. - porder);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T,
PowFunctor<T>>(cuda_ctx, ins, &outs,
pow_functor);
ins = {in_x};
outs = {out_dx};
auto unsigned_pow = UnsignedPowFunctor<T>(porder - 1.);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T,
UnsignedPowFunctor<T>>(
cuda_ctx, ins, &outs, unsigned_pow);
const framework::Tensor* tmp_norm_const = &tmp_norm;
LaunchReduceGradKernel<DeviceContext, T, PNormPostGradFunctor<T>>(
ctx, in_x, tmp_norm_const, in_norm_dy, out_dx, dims, reduce_all);
}
}
};
......
......@@ -326,6 +326,67 @@ class BoolReduceKernel : public framework::OpKernel<OutT> {
}
};
template <typename DeviceContext, typename T, typename Functor>
void LaunchReduceGradKernel(const framework::ExecutionContext& context,
const framework::Tensor* input0,
const framework::Tensor* input1,
const framework::Tensor* input2,
paddle::framework::Tensor* output,
const std::vector<int>& dims,
bool reduce_all = false) {
if (reduce_all) {
auto x = EigenVector<T>::Flatten(*input0);
auto x_reduce = EigenVector<T>::Flatten(*input1);
auto x_reduce_grad = EigenVector<T>::Flatten(*input2);
auto x_grad = EigenVector<T>::Flatten(*output);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto broadcast_dim =
Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
Functor functor;
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
broadcast_dim[0]);
} else {
int rank = input0->dims().size();
switch (rank) {
case 1:
ReduceGradFunctor<DeviceContext, T, 1, Functor>(
context.template device_context<DeviceContext>(), *input0, *input1,
*input2, output, dims);
break;
case 2:
ReduceGradFunctor<DeviceContext, T, 2, Functor>(
context.template device_context<DeviceContext>(), *input0, *input1,
*input2, output, dims);
break;
case 3:
ReduceGradFunctor<DeviceContext, T, 3, Functor>(
context.template device_context<DeviceContext>(), *input0, *input1,
*input2, output, dims);
break;
case 4:
ReduceGradFunctor<DeviceContext, T, 4, Functor>(
context.template device_context<DeviceContext>(), *input0, *input1,
*input2, output, dims);
break;
case 5:
ReduceGradFunctor<DeviceContext, T, 5, Functor>(
context.template device_context<DeviceContext>(), *input0, *input1,
*input2, output, dims);
break;
case 6:
ReduceGradFunctor<DeviceContext, T, 6, Functor>(
context.template device_context<DeviceContext>(), *input0, *input1,
*input2, output, dims);
break;
default:
HandleLargeDimGrad<DeviceContext, T, Functor>(context, input0, input1,
input2, output, dims);
break;
}
}
}
template <typename DeviceContext, typename T, typename Functor,
bool kNoNeedBufferX = false, bool kNoNeedBufferY = false>
class ReduceGradKernel : public framework::OpKernel<T> {
......@@ -362,61 +423,13 @@ class ReduceGradKernel : public framework::OpKernel<T> {
input1 = input2;
}
const std::vector<int> const_dims = dims;
// NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
// not be set as Input in grad Maker, use Out_grad to replace here
if (!input1) input1 = input2;
if (reduce_all) {
auto x = EigenVector<T>::Flatten(*input0);
auto x_reduce = EigenVector<T>::Flatten(*input1);
auto x_reduce_grad = EigenVector<T>::Flatten(*input2);
auto x_grad = EigenVector<T>::Flatten(*output);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto broadcast_dim =
Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
Functor functor;
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
broadcast_dim[0]);
} else {
int rank = input0->dims().size();
switch (rank) {
case 1:
ReduceGradFunctor<DeviceContext, T, 1, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 2:
ReduceGradFunctor<DeviceContext, T, 2, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 3:
ReduceGradFunctor<DeviceContext, T, 3, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 4:
ReduceGradFunctor<DeviceContext, T, 4, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 5:
ReduceGradFunctor<DeviceContext, T, 5, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
case 6:
ReduceGradFunctor<DeviceContext, T, 6, Functor>(
context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims);
break;
default:
HandleLargeDimGrad<DeviceContext, T, Functor>(context, input0, input1,
input2, output, dims);
break;
}
}
LaunchReduceGradKernel<DeviceContext, T, Functor>(
context, input0, input1, input2, output, const_dims, reduce_all);
}
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -186,7 +186,6 @@ register_unity_group(cc
norm_op.cc
one_hot_op.cc
one_hot_v2_op.cc
p_norm_op.cc
pad2d_op.cc
pad3d_op.cc
pad_constant_like_op.cc
......@@ -468,7 +467,6 @@ register_unity_group(cu
nll_loss_op.cu
norm_op.cu
one_hot_op.cu
p_norm_op.cu
pad2d_op.cu
pad3d_op.cu
pad_constant_like_op.cu
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册