diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
index 30aba42aeee110fb9423053973b59b1929f0075e..89c97b1b4ca81c321b0738d6a08070e4cf9220aa 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -162,7 +162,8 @@ struct DimensionsTransform {
   }
 };
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
+          int NumOuts = 1>
 void LaunchBroadcastElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
@@ -190,11 +191,12 @@ void LaunchBroadcastElementwiseCudaKernel(
   for (int i = 0; i < pt_outputs_tmp.size(); i++) {
     pt_outputs.push_back(pt_outputs_tmp[i].get());
   }
-  pten::LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor>(
+  pten::LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
       ctx, pt_inputs, &pt_outputs, axis, func);
 }
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
+          int NumOuts = 1>
 void LaunchElementwiseCudaKernel(
     const platform::CUDADeviceContext &cuda_ctx,
     const std::vector<const framework::Tensor *> &ins,
@@ -222,8 +224,8 @@ void LaunchElementwiseCudaKernel(
   for (int i = 0; i < pt_outputs_tmp.size(); i++) {
     pt_outputs.push_back(pt_outputs_tmp[i].get());
   }
-  pten::LaunchElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, pt_inputs,
-                                                   &pt_outputs, axis, func);
+  pten::LaunchElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
+      cuda_ctx, pt_inputs, &pt_outputs, axis, func);
 }
 
 }  // namespace operators
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 12fdcd40aa0b1dab1b9566d2074e60ab38a65943..8f4a9dea55bd3a5524c6926470069caebb1949d9 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -38,7 +38,8 @@ namespace kps = paddle::operators::kernel_primitives;
 
 using ElementwiseType = pten::ElementwiseType;
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
+          int NumOuts = 1>
 void LaunchSameDimsElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
@@ -66,8 +67,8 @@ void LaunchSameDimsElementwiseCudaKernel(
   for (int i = 0; i < pt_outputs_tmp.size(); i++) {
     pt_outputs.push_back(pt_outputs_tmp[i].get());
   }
-  pten::LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(ctx, pt_inputs,
-                                                           &pt_outputs, func);
+  pten::LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
+      ctx, pt_inputs, &pt_outputs, func);
 }
 
 }  // namespace operators
diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h
index 0ef2ee2fdf1f48a102ac3bf8213193eec590952c..83d662b14e7fc5b06d8a92886e1b5820cd342bb9 100644
--- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h
+++ b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h
@@ -19,7 +19,11 @@ limitations under the License. */
 
 namespace pten {
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET,
+          typename InT,
+          typename OutT,
+          typename Functor,
+          int NumOuts = 1>
 void LaunchElementwiseCudaKernel(
     const paddle::platform::CUDADeviceContext &cuda_ctx,
     const std::vector<const DenseTensor *> &ins,
@@ -33,14 +37,14 @@ void LaunchElementwiseCudaKernel(
     dims_size.emplace_back(in->dims().size());
   }
   if (no_broadcast_flag) {
-    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor>(
+    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
        cuda_ctx, ins, outs, func);
   } else {
     axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
-    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor>(
+    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
        cuda_ctx, ins, outs, axis, func);
   }
 }
diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
index 3638f1c434933b140743597184ff88da228616b4..9303cf1c7fc479f6c7e49d7f03608bf6fb5a61aa 100644
--- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
+++ b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
@@ -208,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
     int block_offset,
     Functor func) {
   InT args[Arity][VecSize];
-  OutType<OutT, NumOuts> result[VecSize];
+  ConditionalT<OutT, NumOuts> result[VecSize];
 
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
@@ -224,7 +224,7 @@
   constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
   ElementwisePrimitiveCaller<InT,
-                             OutType<OutT, NumOuts>,
+                             ConditionalT<OutT, NumOuts>,
                              VecSize,
                              Functor,
                              Arity,
@@ -455,20 +455,19 @@
           "is %d, the arity of functor is %d.",
           ins.size(),
           kArity));
-  PADDLE_ENFORCE_EQ(kArity,
-                    2,
+  PADDLE_ENFORCE_LE(kArity,
+                    ElementwiseType::kTernary,
                     paddle::platform::errors::InvalidArgument(
-                        "Currently only broadcast of binary is supported and "
-                        "verified, but received %d.",
+                        "Currently, broadcast of at most ternary arity is "
+                        "supported and verified, but received %d.",
                         kArity));
-  PADDLE_ENFORCE_EQ(
-      outs->size(),
-      NumOuts,
-      paddle::platform::errors::InvalidArgument(
-          "Number of outputs shall equal to number of functions, "
-          "but number of outputs is %d, number of functions is %d.",
-          outs->size(),
-          NumOuts));
+  PADDLE_ENFORCE_EQ(outs->size(),
+                    NumOuts,
+                    paddle::platform::errors::InvalidArgument(
+                        "Number of outputs shall be equal to number of functions, "
+                        "but number of outputs is %d, number of functions is %d.",
+                        outs->size(),
+                        NumOuts));
   int in_vec_size = 4;
   int out_vec_size = 4;
   if (NumOuts > 1) {
diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
index 18591e897a7f0c0ac4548261ecd4c80c1e707037..7c5f3a9778404e706ee706ec1e29076add24c8c1 100644
--- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
+++ b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
@@ -27,16 +27,16 @@
 enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
 
 /* Packing scalar type T(float, int etc.)
    into Array<T, NumOuts> type for supporting multiple-output feature in elementwise system.*/
 template <class T, int Num>
-using OutType =
+using ConditionalT =
     typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;
 
 template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
 struct ElementwiseWriteDataCaller {
   __device__ __forceinline__ void operator()(
       paddle::framework::Array<OutT *, NumOuts> outs,
-      OutType<OutT, NumOuts> src[VecSize],
+      ConditionalT<OutT, NumOuts> src[VecSize],
       int block_offset,
       int num) {
     OutT dst[NumOuts][VecSize];
diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h
index e2659271bdcd9f873fdec524e47e51758c150e3c..f37e3b0b5e3b36c0381810ef167b4038809bfad8 100644
--- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h
+++ b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h
@@ -55,16 +55,17 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           bool IsBoundary>
 __device__ void VectorizedElementwiseKernelImpl(
     const paddle::framework::Array<const InT *__restrict__, Arity> &in,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     int num,
     int data_offset,
     Functor func) {
   InT args[Arity][VecSize];
-  OutT result[VecSize];
+  ConditionalT<OutT, NumOuts> result[VecSize];
 
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
@@ -73,36 +74,53 @@ __device__ void VectorizedElementwiseKernelImpl(
         args[i], in[i] + data_offset, num);
   }
 
-  const bool kCallElementwiseAny =
+  constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
   ElementwisePrimitiveCaller<InT,
-                             OutT,
+                             ConditionalT<OutT, NumOuts>,
                              VecSize,
                              Functor,
                              Arity,
                              kCallElementwiseAny>()(func, args, result);
-  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
-      out + data_offset, result, num);
+
+  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
+      outs, result, data_offset, num);
 }
 
-template <typename InT, typename OutT, typename Functor, int Arity,
-          int VecSize>
+template <typename InT,
+          typename OutT,
+          typename Functor,
+          int Arity,
+          int NumOuts,
+          int VecSize>
 __global__ void VectorizedElementwiseKernel(
     paddle::framework::Array<const InT *__restrict__, Arity> ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     int size,
     int main_offset,
     Functor func) {
   int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
   int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
   for (; data_offset < main_offset; data_offset += stride) {
-    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, false>(
-        ins, out, VecSize * BLOCK_NUM_X, data_offset, func);
+    VectorizedElementwiseKernelImpl<InT,
+                                    OutT,
+                                    Functor,
+                                    Arity,
+                                    NumOuts,
+                                    VecSize,
+                                    false>(
+        ins, outs, VecSize * BLOCK_NUM_X, data_offset, func);
   }
   int num = size - data_offset;
   if (num > 0) {
-    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, true>(
-        ins, out, num, data_offset, func);
+    VectorizedElementwiseKernelImpl<InT,
+                                    OutT,
+                                    Functor,
+                                    Arity,
+                                    NumOuts,
+                                    VecSize,
+                                    true>(ins, outs, num, data_offset, func);
   }
 }
 
@@ -121,7 +139,12 @@ int GetVectorizedSizeForTensors(const std::vector<DenseTensor *> &ins,
   return vec_size;
 }
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor, int VecSize>
+template <ElementwiseType ET,
+          typename InT,
+          typename OutT,
+          typename Functor,
+          int NumOuts,
+          int VecSize>
 void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
                            const std::vector<const DenseTensor *> &ins,
                            std::vector<DenseTensor *> *outs,
@@ -131,11 +154,15 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
   int grid_size =
       ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
   auto stream = ctx.stream();
-  OutT *out_data = (*outs)[0]->mutable_data<OutT>();
   paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
-  for (int i = 0; i < Arity; i++) {
+  paddle::framework::Array<OutT *, NumOuts> outs_data;
+
+  for (int i = 0; i < Arity; ++i) {
     ins_data[i] = ins[i]->data<InT>();
   }
+  for (int i = 0; i < NumOuts; ++i) {
+    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
+  }
 #ifdef PADDLE_WITH_XPU2
   block_size = 128;
   grid_size = 8;
@@ -144,20 +171,26 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
                               OutT,
                               Functor,
                               Arity,
+                              NumOuts,
                               VecSize><<<grid_size, block_size, 0, stream>>>(
-      ins_data, out_data, numel, main_offset, func);
+      ins_data, outs_data, numel, main_offset, func);
 #else
   int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
   VectorizedElementwiseKernel<InT,
                               OutT,
                               Functor,
                               Arity,
+                              NumOuts,
                               VecSize><<<grid_size, block_size, 0, stream>>>(
-      ins_data, out_data, numel, main_offset, func);
+      ins_data, outs_data, numel, main_offset, func);
 #endif
 }
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET,
+          typename InT,
+          typename OutT,
+          typename Functor,
+          int NumOuts = 1>
 void LaunchSameDimsElementwiseCudaKernel(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
@@ -174,19 +207,39 @@ void LaunchSameDimsElementwiseCudaKernel(
           "is %d, the arity of functor is %d.",
           ins.size(),
           kArity));
+  PADDLE_ENFORCE_EQ(outs->size(),
+                    NumOuts,
+                    paddle::platform::errors::InvalidArgument(
+                        "Number of outputs shall be equal to number of functions, "
+                        "but number of outputs is %d, number of functions is %d.",
+                        outs->size(),
+                        NumOuts));
+
+  if (NumOuts > 1) {
+    for (int i = 1; i < NumOuts; ++i) {
+      PADDLE_ENFORCE_EQ(
+          (*outs)[i]->dims(),
+          (*outs)[0]->dims(),
+          paddle::platform::errors::InvalidArgument(
+              "The shape of each output tensor shall be identical, but "
+              "the %d-th output tensor's shape is not.",
+              i));
+    }
+  }
+
   // calculate the max vec_size for all ins and outs
   int vec_size = GetVectorizedSizeForTensors(ins, *outs);
   switch (vec_size) {
     case 4:
-      ElementwiseCudaKernel<ET, InT, OutT, Functor, 4>(
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts, 4>(
           ctx, ins, outs, func);
       break;
     case 2:
-      ElementwiseCudaKernel<ET, InT, OutT, Functor, 2>(
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts, 2>(
          ctx, ins, outs, func);
       break;
     case 1:
-      ElementwiseCudaKernel<ET, InT, OutT, Functor, 1>(
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts, 1>(
          ctx, ins, outs, func);
       break;
     default: {
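
Editor's note: to make the multi-output plumbing above concrete, here is a minimal, hypothetical sketch of how a caller could use the new NumOuts parameter. The functor name FakeDivModFunctor and the tensor variables are illustrative assumptions, not part of this diff; only the ElementwiseType enum, the ConditionalT packing via paddle::framework::Array, and the LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts> entry point come from the change itself.

#include "paddle/fluid/framework/array.h"      // paddle::framework::Array
#include "paddle/fluid/platform/hostdevice.h"  // HOSTDEVICE

// ConditionalT (renamed from OutType in this diff) collapses to the scalar
// type in the single-output case and packs into an Array otherwise:
//   static_assert(std::is_same<pten::ConditionalT<float, 1>, float>::value, "");
//   static_assert(std::is_same<pten::ConditionalT<float, 2>,
//                              paddle::framework::Array<float, 2>>::value, "");

// A binary functor with two outputs per element: with NumOuts == 2 the
// launcher's result type ConditionalT<T, 2> resolves to
// paddle::framework::Array<T, 2>, so the functor returns the packed pair and
// ElementwiseWriteDataCaller scatters element j of the pair into (*outs)[j].
template <typename T>
struct FakeDivModFunctor {
  inline HOSTDEVICE paddle::framework::Array<T, 2> operator()(const T a,
                                                              const T b) const {
    paddle::framework::Array<T, 2> result;
    result[0] = a / b;            // written to (*outs)[0]
    result[1] = a - (a / b) * b;  // written to (*outs)[1]
    return result;
  }
};

// Hypothetical call site (x, y, div_out, mod_out are assumed DenseTensors of
// identical shape, with outputs pre-allocated):
//
//   std::vector<const pten::DenseTensor *> ins = {&x, &y};
//   std::vector<pten::DenseTensor *> outs = {&div_out, &mod_out};
//   pten::LaunchSameDimsElementwiseCudaKernel<pten::ElementwiseType::kBinary,
//                                             T,
//                                             T,
//                                             FakeDivModFunctor<T>,
//                                             2>(
//       ctx, ins, &outs, FakeDivModFunctor<T>());
//
// The new shape check above guards exactly this situation: every entry of
// outs must share one shape, since a single index computation drives all
// NumOuts stores per element.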