未验证 提交 b007a031 编写于 作者: N niuliling123 提交者: GitHub

Delete BASE_SIZE in elementwise_base.h (#39390)

上级 2be20e20
......@@ -104,28 +104,22 @@ class Array<T, 0> {
HOSTDEVICE inline T *GetMutable() { return nullptr; }
HOSTDEVICE inline T &operator[](size_t) {
#if defined(__HIPCC__)
// HIP will have compile error, if use "obj()"
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
// HIP and CUDA will have compile error, if use "obj()"
// function declared in block scope cannot have 'static' storage class
static T obj{};
return obj;
#elif defined(__CUDA_ARCH__)
static T obj();
return obj;
#else
PADDLE_THROW(pten::errors::Unavailable("Array<T, 0> has no element."));
#endif
}
HOSTDEVICE inline const T &operator[](size_t) const {
#if defined(__HIPCC__)
// HIP will have compile error, if use "obj()"
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
// HIP and CUDA will have compile error, if use "obj()"
// function declared in block scope cannot have 'static' storage class
static const T obj{};
return obj;
#elif defined(__CUDA_ARCH__)
static const T obj();
return obj;
#else
PADDLE_THROW(pten::errors::Unavailable("Array<T, 0> has no element."));
#endif
......
......@@ -31,8 +31,6 @@ namespace kps = pten::kps;
#endif
#define BASE_SIZE 1 // To avoid running errors when Arity == 0 in args[Arity]
namespace pten {
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
......@@ -482,7 +480,7 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseFillConst<InT, OutT, VecSize, 1, 1, Functor>(result, func);
kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func);
}
};
......@@ -560,13 +558,12 @@ template <typename InT,
bool IsBoundary>
__device__ void VectorizedElementwiseKernelImpl(
const pten::framework::Array<const _ptr_ InT *__restrict__,
Arity + BASE_SIZE> &in,
const pten::framework::Array<const _ptr_ InT *__restrict__, Arity> &in,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
int num,
int data_offset,
Functor func) {
InT args[Arity + BASE_SIZE][VecSize];
InT args[Arity > 1 ? Arity : 1][VecSize];
ConditionalT<OutT, NumOuts> result[VecSize];
#pragma unroll
......@@ -596,8 +593,7 @@ template <typename InT,
int NumOuts,
int VecSize>
__global__ void VectorizedElementwiseKernel(
pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE>
ins,
pten::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
int size,
int main_offset,
......@@ -637,9 +633,9 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func) {
auto numel = (*outs)[0]->numel();
pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE>
ins_data;
auto numel =
(*outs)[0]->numel(); // To avoid running errors when ins.size()== 0
pten::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
for (int i = 0; i < Arity; ++i) {
......
......@@ -62,8 +62,7 @@ void FullLikeKernel(const ContextT& dev_ctx,
auto value = val.to<float>();
using CommonType = typename std::common_type<
float,
typename std::conditional<
std::is_same<T, paddle::platform::float16>::value,
typename std::conditional<std::is_same<T, pten::dtype::float16>::value,
float,
T>::type>::type;
......@@ -75,7 +74,7 @@ void FullLikeKernel(const ContextT& dev_ctx,
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
true,
paddle::platform::errors::InvalidArgument(
pten::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
"and %f, but now value is %f.",
......
......@@ -420,8 +420,7 @@ template <typename InT,
int NY,
int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseFillConst(OutT* out,
OpFunc compute) {
__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
out[idx] = static_cast<OutT>(compute());
......
......@@ -348,5 +348,18 @@ __device__ __forceinline__ void Reduce(T* out,
}
}
template <typename InT,
typename OutT,
int NX,
int NY,
int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
out[idx] = static_cast<OutT>(compute());
}
}
} // namespace kps
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册