From b007a031053760796d145a8fa010e0a0dd420d53 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Wed, 9 Feb 2022 10:58:22 +0800 Subject: [PATCH] Delete BASE_SIZE in elementwise_base.h (#39390) --- paddle/pten/core/utils/array.h | 14 ++++---------- paddle/pten/kernels/funcs/elementwise_base.h | 18 +++++++----------- paddle/pten/kernels/gpu/full_kernel.cu | 9 ++++----- .../kernels/primitive/compute_primitives.h | 3 +-- .../primitive/compute_primitives_xpu2.h | 13 +++++++++++++ 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/paddle/pten/core/utils/array.h b/paddle/pten/core/utils/array.h index cd43dc7b420..2d6bfbe213b 100644 --- a/paddle/pten/core/utils/array.h +++ b/paddle/pten/core/utils/array.h @@ -104,28 +104,22 @@ class Array { HOSTDEVICE inline T *GetMutable() { return nullptr; } HOSTDEVICE inline T &operator[](size_t) { -#if defined(__HIPCC__) - // HIP will have compile error, if use "obj()" +#if defined(__HIPCC__) || defined(__CUDA_ARCH__) + // HIP and CUDA will have compile error, if use "obj()" // function declared in block scope cannot have 'static' storage class static T obj{}; return obj; -#elif defined(__CUDA_ARCH__) - static T obj(); - return obj; #else PADDLE_THROW(pten::errors::Unavailable("Array has no element.")); #endif } HOSTDEVICE inline const T &operator[](size_t) const { -#if defined(__HIPCC__) - // HIP will have compile error, if use "obj()" +#if defined(__HIPCC__) || defined(__CUDA_ARCH__) + // HIP and CUDA will have compile error, if use "obj()" // function declared in block scope cannot have 'static' storage class static const T obj{}; return obj; -#elif defined(__CUDA_ARCH__) - static const T obj(); - return obj; #else PADDLE_THROW(pten::errors::Unavailable("Array has no element.")); #endif diff --git a/paddle/pten/kernels/funcs/elementwise_base.h b/paddle/pten/kernels/funcs/elementwise_base.h index 34f2ab4a62a..21abf82c7b2 100644 --- a/paddle/pten/kernels/funcs/elementwise_base.h +++ b/paddle/pten/kernels/funcs/elementwise_base.h @@ -31,8 +31,6 @@ namespace kps = pten::kps; #endif -#define BASE_SIZE 1 // To avoid running errors when Arity == 0 in args[Arity] - namespace pten { enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; @@ -482,7 +480,7 @@ struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { - kps::ElementwiseFillConst(result, func); + kps::ElementwiseConstant(result, func); } }; @@ -560,13 +558,12 @@ template __device__ void VectorizedElementwiseKernelImpl( - const pten::framework::Array &in, + const pten::framework::Array &in, pten::framework::Array<_ptr_ OutT *, NumOuts> outs, int num, int data_offset, Functor func) { - InT args[Arity + BASE_SIZE][VecSize]; + InT args[Arity > 1 ? Arity : 1][VecSize]; ConditionalT result[VecSize]; #pragma unroll @@ -596,8 +593,7 @@ template __global__ void VectorizedElementwiseKernel( - pten::framework::Array - ins, + pten::framework::Array ins, pten::framework::Array<_ptr_ OutT *, NumOuts> outs, int size, int main_offset, @@ -637,9 +633,9 @@ void ElementwiseCudaKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { - auto numel = (*outs)[0]->numel(); - pten::framework::Array - ins_data; + auto numel = + (*outs)[0]->numel(); // To avoid running errors when ins.size()== 0 + pten::framework::Array ins_data; pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < Arity; ++i) { diff --git a/paddle/pten/kernels/gpu/full_kernel.cu b/paddle/pten/kernels/gpu/full_kernel.cu index 6464dc97d5b..5dbae41e00c 100644 --- a/paddle/pten/kernels/gpu/full_kernel.cu +++ b/paddle/pten/kernels/gpu/full_kernel.cu @@ -62,10 +62,9 @@ void FullLikeKernel(const ContextT& dev_ctx, auto value = val.to(); using CommonType = typename std::common_type< float, - typename std::conditional< - std::is_same::value, - float, - T>::type>::type; + typename std::conditional::value, + float, + T>::type>::type; auto common_type_value = static_cast(value); @@ -75,7 +74,7 @@ void FullLikeKernel(const ContextT& dev_ctx, (common_type_value <= static_cast(std::numeric_limits::max())), true, - paddle::platform::errors::InvalidArgument( + pten::errors::InvalidArgument( "The filled value is out of range for target type, " "current kernel type is %s, the range should between %f " "and %f, but now value is %f.", diff --git a/paddle/pten/kernels/primitive/compute_primitives.h b/paddle/pten/kernels/primitive/compute_primitives.h index f854cf95aac..449c81b915e 100644 --- a/paddle/pten/kernels/primitive/compute_primitives.h +++ b/paddle/pten/kernels/primitive/compute_primitives.h @@ -420,8 +420,7 @@ template -__device__ __forceinline__ void ElementwiseFillConst(OutT* out, - OpFunc compute) { +__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { #pragma unroll for (int idx = 0; idx < NX * NY; idx++) { out[idx] = static_cast(compute()); diff --git a/paddle/pten/kernels/primitive/compute_primitives_xpu2.h b/paddle/pten/kernels/primitive/compute_primitives_xpu2.h index d7282c089fc..e083280aa2e 100644 --- a/paddle/pten/kernels/primitive/compute_primitives_xpu2.h +++ b/paddle/pten/kernels/primitive/compute_primitives_xpu2.h @@ -348,5 +348,18 @@ __device__ __forceinline__ void Reduce(T* out, } } +template +__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { +#pragma unroll + for (int idx = 0; idx < NX * NY; idx++) { + out[idx] = static_cast(compute()); + } +} + } // namespace kps } // namespace pten -- GitLab