From 424700ff63a2d8bfaf1e45b7af60bacdf5ddd62f Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Tue, 8 Feb 2022 14:00:56 +0800
Subject: [PATCH] Replace clip, bce_loss, full and full_like with elementwise
 (#39197)

* Replace clip, bce_loss, full and full_like with elementwise
---
 paddle/fluid/operators/bce_loss_op.cu        | 70 ++++++++--------
 paddle/fluid/operators/clip_op.h             | 10 +++
 paddle/fluid/platform/function_traits.h      | 17 +++-
 paddle/pten/kernels/funcs/elementwise_base.h | 25 ++++--
 paddle/pten/kernels/gpu/full_kernel.cu       | 84 ++++++++++++++++++-
 paddle/pten/kernels/gpu/scale_kernel.cu      |  9 +-
 .../kernels/primitive/compute_primitives.h   | 14 ++++
 7 files changed, 179 insertions(+), 50 deletions(-)

diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu
index 6595d6decc..6ab2e8a6df 100644
--- a/paddle/fluid/operators/bce_loss_op.cu
+++ b/paddle/fluid/operators/bce_loss_op.cu
@@ -21,40 +21,45 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-
-template <typename T>
-struct BCELossGradFunctor {
-  T one = static_cast<T>(1.0f);
-  T eps = static_cast<T>(1e-12);
-  __device__ __forceinline__ T operator()(const T x, const T label,
-                                          const T dout) const {
-    T term1 = max((one - x) * x, eps);
-    return (dout * (x - label) / term1);
-  }
-};
-
-template <typename T>
-__global__ void GPUBCELossForward(const T* x_data, const T* label_data,
-                                  T* out_data, const int in_numel) {
-  CUDA_KERNEL_LOOP(i, in_numel) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T one = static_cast<T>(1.);
-    T neg_100 = static_cast<T>(-100.);
-
-    PADDLE_ENFORCE(
-        (x >= static_cast<T>(0)) && (x <= one),
-        "Input is expected to be within the interval [0, 1], but received %f.",
-        x);
-
-    T term1 = max(real_log(x), neg_100);
-    T term2 = max(real_log(one - x), neg_100);
-
-    out_data[i] = ((label - one) * term2) - (label * term1);
-  }
-}
+template <typename T>
+struct BCELossFunctor {
+  T one;
+  T neg_100;
+
+  HOSTDEVICE inline BCELossFunctor() {
+    one = static_cast<T>(1.0f);
+    neg_100 = static_cast<T>(-100.);
+  }
+
+  HOSTDEVICE inline T operator()(const T& x, const T& label) const {
+    PADDLE_ENFORCE(
+        (x >= static_cast<T>(0)) && (x <= one),
+        "Input is expected to be within the interval [0, 1], but received %f.",
+        x);
+    T term1 = max(real_log(x), neg_100);
+    T term2 = max(real_log(one - x), neg_100);
+    return (((label - one) * term2) - (label * term1));
+  }
+};
+
+template <typename T>
+struct BCELossGradFunctor {
+  T one;
+  T eps;
+
+  HOSTDEVICE inline BCELossGradFunctor() {
+    one = static_cast<T>(1.0f);
+    eps = static_cast<T>(1e-12);
+  }
+
+  HOSTDEVICE inline T operator()(const T& x, const T& label,
+                                 const T& dout) const {
+    T term1 = max((one - x) * x, eps);
+    return (dout * (x - label) / term1);
+  }
+};
+
+using Tensor = framework::Tensor;
 
 template <typename T>
 class BCELossCUDAKernel : public framework::OpKernel<T> {
@@ -63,18 +68,13 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
     auto* x = ctx.Input<Tensor>("X");
     auto* labels = ctx.Input<Tensor>("Label");
     auto* out = ctx.Output<Tensor>("Out");
-
-    const auto* x_data = x->data<T>();
-    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
-    auto x_numel = x->numel();
-
-    auto& dev_ctx = ctx.cuda_device_context();
-    platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(dev_ctx, x_numel);
-
-    GPUBCELossForward<T><<<config.block_per_grid, config.thread_per_block, 0,
-                           dev_ctx.stream()>>>(x_data, labels->data<T>(),
-                                               out_data, x_numel);
+    out->mutable_data<T>(ctx.GetPlace());
+    std::vector<const framework::Tensor*> ins = {x, labels};
+    std::vector<framework::Tensor*> outs = {out};
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto functor = BCELossFunctor<T>();
+    paddle::operators::LaunchSameDimsElementwiseCudaKernel<
+        ElementwiseType::kBinary, T, T>(dev_ctx, ins, &outs, functor);
   }
 };
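The change above is the template for the rest of this patch: a hand-rolled __global__ kernel plus manual launch-config plumbing is replaced by a small device-side functor handed to a shared elementwise launcher. For reference, here is a minimal standalone CUDA sketch of that pattern; AbsDiffFunctor and ApplyBinary are illustrative stand-ins, not Paddle APIs, and the real launcher additionally handles vectorized loads and multiple outputs:

    // A device-callable binary functor, playing the role of BCELossFunctor.
    template <typename T>
    struct AbsDiffFunctor {
      __device__ T operator()(const T& a, const T& b) const {
        return a > b ? a - b : b - a;
      }
    };

    // One generic kernel serves every binary functor; this is the job that
    // LaunchSameDimsElementwiseCudaKernel does for the kernels in this patch.
    template <typename T, typename Functor>
    __global__ void ApplyBinary(const T* a, const T* b, T* out, int n,
                                Functor f) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = f(a[i], b[i]);
    }

    // e.g. ApplyBinary<<<(n + 255) / 256, 256>>>(a, b, out, n,
    //                                            AbsDiffFunctor<float>());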
diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h
index 5aff62656f..47bb61a77f 100644
--- a/paddle/fluid/operators/clip_op.h
+++ b/paddle/fluid/operators/clip_op.h
@@ -172,6 +172,15 @@ class ClipGradKernel : public framework::OpKernel<T> {
         context.Output<Tensor>(framework::GradVarName("X"));
     if (d_x != nullptr) {
       auto* x = context.Input<Tensor>("X");
+#if defined(__NVCC__) || defined(__HIPCC__)
+      std::vector<const framework::Tensor*> ins = {d_out, x};
+      std::vector<framework::Tensor*> outs = {d_x};
+      auto functor = ClipGradFunctor<T>(min, max);
+      d_x->mutable_data<T>(context.GetPlace());
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          context.template device_context<platform::CUDADeviceContext>(), ins,
+          &outs, functor);
+#else
       int64_t numel = d_out->numel();
       auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
       const T* d_out_data = d_out->data<T>();
@@ -179,6 +188,7 @@ class ClipGradKernel : public framework::OpKernel<T> {
       Transform<DeviceContext> trans;
       trans(context.template device_context<DeviceContext>(), d_out_data,
             d_out_data + numel, x_data, d_x_data, ClipGradFunctor<T>(min, max));
+#endif
     }
   }
 };

diff --git a/paddle/fluid/platform/function_traits.h b/paddle/fluid/platform/function_traits.h
index e1041184e6..eca78e03e1 100644
--- a/paddle/fluid/platform/function_traits.h
+++ b/paddle/fluid/platform/function_traits.h
@@ -18,6 +18,18 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
+template <int Arity, typename... Args>
+struct IsPointerArgs {
+  static_assert(Arity == sizeof...(Args), "Arity and Args do not match!");
+  static const bool value = false;
+};
+
+template <typename... Args>
+struct IsPointerArgs<1, Args...> {
+  static_assert(1 == sizeof...(Args), "Arity and Args do not match!");
+  static const bool value = std::is_pointer<
+      typename std::tuple_element<0, std::tuple<Args...>>::type>::value;
+};
+
 // Declare a template class with a single template parameter.
 template <typename>
@@ -41,10 +53,7 @@ struct FunctionTraits
 template <typename Ret, typename... Args>
 struct FunctionTraits<Ret(Args...)> {
   static const size_t arity = sizeof...(Args);
-  static const bool has_pointer_args =
-      (arity == 1) &&
-      (std::is_pointer<
-          typename std::tuple_element<0, std::tuple<Args...>>::type>::value);
+  static const bool has_pointer_args = IsPointerArgs<arity, Args...>::value;
 };
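The IsPointerArgs helper is what lets FunctionTraits accept the new zero-argument functors: in the old expression, std::tuple_element<0, std::tuple<Args...>> was instantiated even when arity != 1, and that instantiation is ill-formed for an empty parameter pack (the && short-circuits the value, not the template instantiation). Dispatching on Arity first confines the tuple_element access to the one-argument case. A compilable reduction using the same names as the patch:

    #include <tuple>
    #include <type_traits>

    template <int Arity, typename... Args>
    struct IsPointerArgs {
      static const bool value = false;  // Arity != 1: Args is never inspected
    };

    template <typename... Args>
    struct IsPointerArgs<1, Args...> {  // only here is tuple_element formed
      static const bool value = std::is_pointer<
          typename std::tuple_element<0, std::tuple<Args...>>::type>::value;
    };

    static_assert(!IsPointerArgs<0>::value, "nullary: no pointer argument");
    static_assert(IsPointerArgs<1, float*>::value, "unary pointer argument");
    static_assert(!IsPointerArgs<2, int, int>::value, "by-value binary args");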
diff --git a/paddle/pten/kernels/funcs/elementwise_base.h b/paddle/pten/kernels/funcs/elementwise_base.h
index d102fd6371..34f2ab4a62 100644
--- a/paddle/pten/kernels/funcs/elementwise_base.h
+++ b/paddle/pten/kernels/funcs/elementwise_base.h
@@ -31,6 +31,8 @@
 namespace kps = pten::kps;
 #endif
 
+#define BASE_SIZE 1  // Avoids a zero-sized args[Arity] when Arity == 0
+
 namespace pten {
 
 enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
@@ -475,6 +477,15 @@ struct ElementwisePrimitiveCaller {
   }
 };
 
+template <typename InT, typename OutT, int VecSize, typename Functor>
+struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
+  __device__ inline void operator()(Functor func,
+                                    InT (*args)[VecSize],
+                                    OutT *result) {
+    kps::ElementwiseFillConst<InT, OutT, VecSize, 1, 1>(result, func);
+  }
+};
+
 template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
   __device__ inline void operator()(Functor func,
@@ -548,12 +559,14 @@ template
 __device__ void VectorizedElementwiseKernelImpl(
-    const pten::framework::Array<const _ptr_ InT *__restrict__, Arity> &in,
+
+    const pten::framework::Array<const _ptr_ InT *__restrict__,
+                                 Arity + BASE_SIZE> &in,
     pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
     int num,
     int data_offset,
     Functor func) {
-  InT args[Arity][VecSize];
+  InT args[Arity + BASE_SIZE][VecSize];
   ConditionalT<OutT, NumOuts> result[VecSize];
 
 #pragma unroll
@@ -583,7 +596,8 @@ template
 __global__ void VectorizedElementwiseKernel(
-    pten::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
+    pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE>
+        ins,
     pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
     int size,
     int main_offset,
@@ -623,8 +637,9 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
                            const std::vector<const DenseTensor *> &ins,
                            std::vector<DenseTensor *> *outs,
                            Functor func) {
-  auto numel = ins[0]->numel();
-  pten::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
+  auto numel = (*outs)[0]->numel();
+  pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE>
+      ins_data;
   pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
 
   for (int i = 0; i < Arity; ++i) {
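BASE_SIZE exists because the new nullary functors make Arity == 0, and a zero-length array (InT args[0][VecSize], or a zero-length pten::framework::Array) is not valid C++. Padding every arity-sized buffer by one unused slot keeps the declarations well-formed while leaving the Arity > 0 code paths untouched. A compilable illustration with a hypothetical buffer type:

    constexpr int BASE_SIZE = 1;  // same trick as the macro above

    template <typename InT, int Arity, int VecSize>
    struct ArgsBuffer {
      // Without the pad this is a zero-sized array when Arity == 0;
      // the spare row is never read on the Arity == 0 path.
      InT data[Arity + BASE_SIZE][VecSize];
    };

    ArgsBuffer<float, 0, 4> nullary_ok;  // legal only thanks to the pad
    ArgsBuffer<float, 2, 4> binary;      // unchanged apart from a spare row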
diff --git a/paddle/pten/kernels/gpu/full_kernel.cu b/paddle/pten/kernels/gpu/full_kernel.cu
index 2f6346daa8..6464dc97d5 100644
--- a/paddle/pten/kernels/gpu/full_kernel.cu
+++ b/paddle/pten/kernels/gpu/full_kernel.cu
@@ -16,7 +16,89 @@ limitations under the License. */
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/kernel_registry.h"
-#include "paddle/pten/kernels/impl/full_kernel_impl.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"
+
+namespace pten {
+
+template <typename InT, typename OutT = InT>
+struct FullFunctor {
+  OutT value;
+
+  template <typename VType>
+  explicit inline FullFunctor(VType val) {
+    value = static_cast<OutT>(val);
+  }
+
+  __device__ __forceinline__ OutT operator()() const {
+    return static_cast<OutT>(value);
+  }
+};
+
+template <typename T, typename ContextT>
+void FullKernel(const ContextT& dev_ctx,
+                const ScalarArray& shape,
+                const Scalar& val,
+                DenseTensor* out) {
+  out->Resize(paddle::framework::make_ddim(shape.GetData()));
+  int numel = out->numel();
+  out->mutable_data<T>(dev_ctx.GetPlace());
+  if (numel > 0) {
+    // In the transformer model the numel of the output can be zero.
+    std::vector<const DenseTensor*> inputs = {};
+    std::vector<DenseTensor*> outputs = {out};
+    // This kernel has no input, so inputs.size() == 0. kUnary is used, but
+    // no data is loaded inside the kernel because the functor takes zero
+    // parameters.
+    pten::funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary,
+                                                     T,
+                                                     T>(
+        dev_ctx, inputs, &outputs, FullFunctor<T>(val.to<T>()));
+  }
+}
+
+template <typename T, typename ContextT>
+void FullLikeKernel(const ContextT& dev_ctx,
+                    const Scalar& val,
+                    DenseTensor* out) {
+  auto value = val.to<float>();
+  using CommonType = typename std::common_type<
+      float,
+      typename std::conditional<
+          std::is_same<T, paddle::platform::float16>::value,
+          float,
+          T>::type>::type;
+
+  auto common_type_value = static_cast<CommonType>(value);
+
+  PADDLE_ENFORCE_EQ(
+      (common_type_value >=
+       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
+          (common_type_value <=
+           static_cast<CommonType>(std::numeric_limits<T>::max())),
+      true,
+      paddle::platform::errors::InvalidArgument(
+          "The filled value is out of range for target type, "
+          "current kernel type is %s, the range should be between %f "
+          "and %f, but now value is %f.",
+          typeid(T).name(),
+          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
+          static_cast<CommonType>(std::numeric_limits<T>::max()),
+          static_cast<float>(value)));
+
+  std::vector<const DenseTensor*> inputs = {};
+  std::vector<DenseTensor*> outputs = {out};
+  out->mutable_data<T>(dev_ctx.GetPlace());
+  // This kernel has no input, so inputs.size() == 0. kUnary is used, but
+  // no data is loaded inside the kernel because the functor takes zero
+  // parameters.
+  int numel = out->numel();
+  if (numel > 0) {
+    pten::funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary,
+                                                     T,
+                                                     T>(
+        dev_ctx, inputs, &outputs, FullFunctor<T>(value));
+  }
+}
+
+}  // namespace pten
 
 PT_REGISTER_KERNEL(full,
                    GPU,

diff --git a/paddle/pten/kernels/gpu/scale_kernel.cu b/paddle/pten/kernels/gpu/scale_kernel.cu
index 6cf84acd9d..5add34a230 100644
--- a/paddle/pten/kernels/gpu/scale_kernel.cu
+++ b/paddle/pten/kernels/gpu/scale_kernel.cu
@@ -28,11 +28,10 @@ struct ScaleFunctor {
   InT scale;
   bool bias_after_scale;
 
-  ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle) {
-    scale = scale_data;
-    bias = bias_data;
-    bias_after_scale = is_bias_after_sacle;
-  }
+  ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_scale)
+      : bias(bias_data),
+        scale(scale_data),
+        bias_after_scale(is_bias_after_scale) {}
 
   __device__ __forceinline__ InT operator()(const InT x) const {
     if (bias_after_scale) {

diff --git a/paddle/pten/kernels/primitive/compute_primitives.h b/paddle/pten/kernels/primitive/compute_primitives.h
index ac812c9c9f..f854cf95aa 100644
--- a/paddle/pten/kernels/primitive/compute_primitives.h
+++ b/paddle/pten/kernels/primitive/compute_primitives.h
@@ -414,5 +414,19 @@ __device__ __forceinline__ void Reduce(T* out,
   }
 }
 
+template <typename InT,
+          typename OutT,
+          int NX,
+          int NY,
+          int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseFillConst(OutT* out,
+                                                     OpFunc compute) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; idx++) {
+    out[idx] = static_cast<OutT>(compute());
+  }
+}
+
 }  // namespace kps
 }  // namespace pten
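With the pieces above in place, full and full_like become ordinary elementwise launches whose functor takes no inputs: kps::ElementwiseFillConst just invokes the functor once per output slot, and nothing is read from global memory. A standalone toy equivalent of that fill path; ConstFunctor and FillKernel are illustrative names, not the Paddle APIs:

    // Nullary functor carrying the fill value, like FullFunctor above.
    template <typename OutT>
    struct ConstFunctor {
      OutT value;
      explicit ConstFunctor(OutT v) : value(v) {}
      __device__ OutT operator()() const { return value; }
    };

    // Stand-in for the kUnary launch with zero inputs: the functor is the
    // only data source, so no input pointers are dereferenced.
    template <typename OutT, typename Functor>
    __global__ void FillKernel(OutT* out, int n, Functor f) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = f();
    }

    // e.g. FillKernel<<<(n + 255) / 256, 256>>>(d_out, n,
    //                                           ConstFunctor<float>(1.0f));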