未验证 提交 b007a031 编写于 作者: N niuliling123 提交者: GitHub

Delete BASE_SIZE in elementwise_base.h (#39390)

上级 2be20e20
...@@ -104,28 +104,22 @@ class Array<T, 0> { ...@@ -104,28 +104,22 @@ class Array<T, 0> {
HOSTDEVICE inline T *GetMutable() { return nullptr; } HOSTDEVICE inline T *GetMutable() { return nullptr; }
HOSTDEVICE inline T &operator[](size_t) { HOSTDEVICE inline T &operator[](size_t) {
#if defined(__HIPCC__) #if defined(__HIPCC__) || defined(__CUDA_ARCH__)
// HIP will have compile error, if use "obj()" // HIP and CUDA will have compile error, if use "obj()"
// function declared in block scope cannot have 'static' storage class // function declared in block scope cannot have 'static' storage class
static T obj{}; static T obj{};
return obj; return obj;
#elif defined(__CUDA_ARCH__)
static T obj();
return obj;
#else #else
PADDLE_THROW(pten::errors::Unavailable("Array<T, 0> has no element.")); PADDLE_THROW(pten::errors::Unavailable("Array<T, 0> has no element."));
#endif #endif
} }
HOSTDEVICE inline const T &operator[](size_t) const { HOSTDEVICE inline const T &operator[](size_t) const {
#if defined(__HIPCC__) #if defined(__HIPCC__) || defined(__CUDA_ARCH__)
// HIP will have compile error, if use "obj()" // HIP and CUDA will have compile error, if use "obj()"
// function declared in block scope cannot have 'static' storage class // function declared in block scope cannot have 'static' storage class
static const T obj{}; static const T obj{};
return obj; return obj;
#elif defined(__CUDA_ARCH__)
static const T obj();
return obj;
#else #else
PADDLE_THROW(pten::errors::Unavailable("Array<T, 0> has no element.")); PADDLE_THROW(pten::errors::Unavailable("Array<T, 0> has no element."));
#endif #endif
......
...@@ -31,8 +31,6 @@ namespace kps = pten::kps; ...@@ -31,8 +31,6 @@ namespace kps = pten::kps;
#endif #endif
#define BASE_SIZE 1 // To avoid running errors when Arity == 0 in args[Arity]
namespace pten { namespace pten {
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
...@@ -482,7 +480,7 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> { ...@@ -482,7 +480,7 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__ inline void operator()(Functor func, __device__ inline void operator()(Functor func,
InT (*args)[VecSize], InT (*args)[VecSize],
OutT *result) { OutT *result) {
kps::ElementwiseFillConst<InT, OutT, VecSize, 1, 1, Functor>(result, func); kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func);
} }
}; };
...@@ -560,13 +558,12 @@ template <typename InT, ...@@ -560,13 +558,12 @@ template <typename InT,
bool IsBoundary> bool IsBoundary>
__device__ void VectorizedElementwiseKernelImpl( __device__ void VectorizedElementwiseKernelImpl(
const pten::framework::Array<const _ptr_ InT *__restrict__, const pten::framework::Array<const _ptr_ InT *__restrict__, Arity> &in,
Arity + BASE_SIZE> &in,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs, pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
int num, int num,
int data_offset, int data_offset,
Functor func) { Functor func) {
InT args[Arity + BASE_SIZE][VecSize]; InT args[Arity > 1 ? Arity : 1][VecSize];
ConditionalT<OutT, NumOuts> result[VecSize]; ConditionalT<OutT, NumOuts> result[VecSize];
#pragma unroll #pragma unroll
...@@ -596,8 +593,7 @@ template <typename InT, ...@@ -596,8 +593,7 @@ template <typename InT,
int NumOuts, int NumOuts,
int VecSize> int VecSize>
__global__ void VectorizedElementwiseKernel( __global__ void VectorizedElementwiseKernel(
pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE> pten::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
ins,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs, pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
int size, int size,
int main_offset, int main_offset,
...@@ -637,9 +633,9 @@ void ElementwiseCudaKernel(const KPDevice &ctx, ...@@ -637,9 +633,9 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins, const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs, std::vector<DenseTensor *> *outs,
Functor func) { Functor func) {
auto numel = (*outs)[0]->numel(); auto numel =
pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE> (*outs)[0]->numel(); // To avoid running errors when ins.size()== 0
ins_data; pten::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data; pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
for (int i = 0; i < Arity; ++i) { for (int i = 0; i < Arity; ++i) {
......
...@@ -62,10 +62,9 @@ void FullLikeKernel(const ContextT& dev_ctx, ...@@ -62,10 +62,9 @@ void FullLikeKernel(const ContextT& dev_ctx,
auto value = val.to<float>(); auto value = val.to<float>();
using CommonType = typename std::common_type< using CommonType = typename std::common_type<
float, float,
typename std::conditional< typename std::conditional<std::is_same<T, pten::dtype::float16>::value,
std::is_same<T, paddle::platform::float16>::value, float,
float, T>::type>::type;
T>::type>::type;
auto common_type_value = static_cast<CommonType>(value); auto common_type_value = static_cast<CommonType>(value);
...@@ -75,7 +74,7 @@ void FullLikeKernel(const ContextT& dev_ctx, ...@@ -75,7 +74,7 @@ void FullLikeKernel(const ContextT& dev_ctx,
(common_type_value <= (common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())), static_cast<CommonType>(std::numeric_limits<T>::max())),
true, true,
paddle::platform::errors::InvalidArgument( pten::errors::InvalidArgument(
"The filled value is out of range for target type, " "The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f " "current kernel type is %s, the range should between %f "
"and %f, but now value is %f.", "and %f, but now value is %f.",
......
...@@ -420,8 +420,7 @@ template <typename InT, ...@@ -420,8 +420,7 @@ template <typename InT,
int NY, int NY,
int BlockSize, int BlockSize,
class OpFunc> class OpFunc>
__device__ __forceinline__ void ElementwiseFillConst(OutT* out, __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
OpFunc compute) {
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX * NY; idx++) { for (int idx = 0; idx < NX * NY; idx++) {
out[idx] = static_cast<OutT>(compute()); out[idx] = static_cast<OutT>(compute());
......
...@@ -348,5 +348,18 @@ __device__ __forceinline__ void Reduce(T* out, ...@@ -348,5 +348,18 @@ __device__ __forceinline__ void Reduce(T* out,
} }
} }
/**
 * Fills a per-thread output buffer with values produced by a nullary functor.
 *
 * The functor `compute` is invoked once per element, NX * NY times in total;
 * each result is converted to OutT with static_cast and stored at consecutive
 * positions of `out`, starting at out[0].
 *
 * Template parameters:
 *   InT, BlockSize - not referenced in the body; kept so the signature stays
 *                    compatible with existing call sites.
 *   OutT           - element type written to `out`.
 *   NX, NY         - their product is the number of elements written.
 *   OpFunc         - callable taking no arguments.
 *
 * out     - destination buffer; must have room for NX * NY elements.
 * compute - functor called with no arguments to produce each value.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
  // Trip count is known at compile time, so the loop can be fully unrolled.
  constexpr int kNumElements = NX * NY;
#pragma unroll
  for (int i = 0; i < kNumElements; ++i) {
    out[i] = static_cast<OutT>(compute());
  }
}
} // namespace kps } // namespace kps
} // namespace pten } // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册