From 89d38f5540025fe72b7af532a2b14615c5e86e98 Mon Sep 17 00:00:00 2001
From: limingshu <61349199+JamesLim-sy@users.noreply.github.com>
Date: Mon, 27 Dec 2021 14:05:24 +0800
Subject: [PATCH] Support multi-outputs feature for broadcast ops (#38329)

* No harm to KP

* Pass the compile stage

* change the WriteData function

* fix template bugs and pass ctest of current elementwise

* for passing partial template specialization of template function in CI-ROCm

* To make 'WriteData' function flexible.

* a less harmful way to support multi-output

* a less harmful way to support multi-output
---
 .../kernel_primitives/datamover_primitives.h  |   8 +-
 .../elementwise/elementwise_broadcast.cu.h    | 135 +++++++++++++-----
 .../cuda/elementwise/elementwise_common.cu.h  |  41 ++++++
 3 files changed, 142 insertions(+), 42 deletions(-)

diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
index 19355434955..ce45ed0301e 100644
--- a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
+++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
@@ -254,8 +254,8 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
       }
     }
   } else {  // blockDim,x * NX < num
-    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-    const int kVectorsPerThread = NX / kVectorSize;
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
     int thread_offset = threadIdx.x * kVectorsPerThread;
 
     using VecType = details::VectorType<T, kVectorSize>;
@@ -441,8 +441,8 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
     }
   } else {
     // Vector type
-    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-    const int kVectorsPerThread = NX / kVectorSize;
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
     int thread_offset = threadIdx.x * kVectorsPerThread;
 
     using VecType = details::VectorType<T, kVectorSize>;
diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
index ccdeb70002b..3638f1c4349 100644
--- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
+++ b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h
@@ -193,12 +193,13 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank,
           bool IsBoundary = false>
 __device__ void ElementwiseBroadcastKernelImpl(
     const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     const paddle::framework::Array<bool, Arity> &use_broadcast,
     uint32_t numel,
     const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -207,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
     int block_offset,
     Functor func) {
   InT args[Arity][VecSize];
-  OutT result[VecSize];
+  OutType<OutT, NumOuts> result[VecSize];
 
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
@@ -220,28 +221,29 @@ __device__ void ElementwiseBroadcastKernelImpl(
         num,
         use_broadcast[i]);
   }
-
-  const bool kCallElementwiseAny =
+  constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
   ElementwisePrimitiveCaller<InT,
-                             OutT,
+                             OutType<OutT, NumOuts>,
                              VecSize,
                              Functor,
                              Arity,
                              kCallElementwiseAny>()(func, args, result);
-  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
-      out + block_offset, result, num);
+
+  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
+      outs, result, block_offset, num);
 }
 
 template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank>
 __global__ void ElementwiseBroadcastKernel(
     paddle::framework::Array<const InT *__restrict__, Arity> ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     paddle::framework::Array<bool, Arity> use_broadcast,
     uint32_t numel,
     paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -251,16 +253,18 @@ __global__ void ElementwiseBroadcastKernel(
     Functor func) {
   int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
   int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+
 #ifdef PADDLE_WITH_XPU2
   for (; block_offset < main_offset; block_offset += stride) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    false>(ins,
-                                          out,
+                                          outs,
                                           use_broadcast,
                                           numel,
                                           configs,
@@ -273,22 +277,23 @@ __global__ void ElementwiseBroadcastKernel(
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
-
 #else
   if (block_offset < main_offset) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    false>(ins,
-                                          out,
+                                          outs,
                                           use_broadcast,
                                           numel,
                                           configs,
@@ -300,10 +305,11 @@ __global__ void ElementwiseBroadcastKernel(
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
 #endif
 }
@@ -312,25 +318,30 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank>
 void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                   const std::vector<const DenseTensor *> &ins,
-                  DenseTensor *out,
+                  std::vector<DenseTensor *> *outs,
                   Functor func,
                   DimensionsTransform merge_dims) {
-  int numel = out->numel();
+  int numel = (*outs)[0]->numel();
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
 
   int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
   int tail_tid = numel % (VecSize * threads);
   auto stream = ctx.stream();
-  OutT *out_data = out->mutable_data<OutT>();
 
   paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
   paddle::framework::Array<bool, Arity> use_broadcast;
   paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
+  paddle::framework::Array<OutT *, NumOuts> outs_data;
+
+  for (int i = 0; i < NumOuts; ++i) {
+    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
+  }
 
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
@@ -343,6 +354,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
           merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
   }
+
 #ifdef PADDLE_WITH_XPU2
   threads = 128;
   blocks = 8;
@@ -352,9 +364,10 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                              OutT,
                              Functor,
                              Arity,
+                             NumOuts,
                              VecSize,
                              Rank><<<blocks, threads, stream>>>(ins_data,
-                                                                out_data,
+                                                                outs_data,
                                                                 use_broadcast,
                                                                 numel,
                                                                 configs,
@@ -366,10 +379,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                              OutT,
                              Functor,
                              Arity,
+                             NumOuts,
                              VecSize,
                              Rank><<<blocks, threads, 0, stream>>>(
       ins_data,
-      out_data,
+      outs_data,
       use_broadcast,
       numel,
       configs,
@@ -379,19 +393,24 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
 #endif
 }
 
-template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
+template <typename InT,
+          typename OutT,
+          typename Functor,
+          int Arity,
+          int NumOuts,
+          int VecSize>
 void LaunchBroadcastKernelForDifferentVecSize(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
-    DenseTensor *out,
+    std::vector<DenseTensor *> *outs,
     int axis,
     Functor func) {
-  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
+  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
 
-#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                   \
-  case rank: {                                              \
-    LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>( \
-        ctx, ins, out, func, merge_dims);                   \
+#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                            \
+  case rank: {                                                       \
+    LaunchKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
+        ctx, ins, outs, func, merge_dims);                           \
   } break;
 
   switch (merge_dims.dim_size) {
@@ -414,7 +433,11 @@ void LaunchBroadcastKernelForDifferentVecSize(
 #undef CALL_BROADCAST_FOR_DIM_SIZE
 }
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET,
+          typename InT,
+          typename OutT,
+          typename Functor,
+          int NumOuts = 1>
 void LaunchBroadcastElementwiseCudaKernel(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
@@ -438,32 +461,68 @@ void LaunchBroadcastElementwiseCudaKernel(
                         "Currently only broadcast of binary is supported and "
                         "verified, but received %d.",
                         kArity));
-
+  PADDLE_ENFORCE_EQ(
+      outs->size(),
+      NumOuts,
+      paddle::platform::errors::InvalidArgument(
+          "Number of outputs shall equal to number of functions, "
+          "but number of outputs is %d, number of functions is %d.",
+          outs->size(),
+          NumOuts));
   int in_vec_size = 4;
-  DenseTensor *out = (*outs)[0];
+  int out_vec_size = 4;
+  if (NumOuts > 1) {
+    for (int i = 0; i < NumOuts; ++i) {
+      PADDLE_ENFORCE_EQ(
+          (*outs)[i]->dims(),
+          (*outs)[0]->dims(),
+          paddle::platform::errors::InvalidArgument(
+              "The shape of each output tensor shall be identical yet, but "
+              "%dth output tensor`s shape is not.",
+              i));
+      out_vec_size = std::min(
+          paddle::platform::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()),
+          out_vec_size);
+    }
+  } else {
+    out_vec_size =
+        paddle::platform::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
+  }
+
   for (auto *in : ins) {
     auto temp_size = paddle::platform::GetVectorizedSize<InT>(in->data<InT>());
-    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
-                                            : in_vec_size;
+    in_vec_size = in->dims() == (*outs)[0]->dims()
+                      ? std::min(temp_size, in_vec_size)
+                      : in_vec_size;
   }
-  int out_vec_size =
-      paddle::platform::GetVectorizedSize<OutT>(out->data<OutT>());
   int vec_size = std::min(out_vec_size, in_vec_size);
 
   switch (vec_size) {
     case 4: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               4>(ctx, ins, outs, axis, func);
       break;
     }
     case 2: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               2>(ctx, ins, outs, axis, func);
       break;
     }
     case 1: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               1>(ctx, ins, outs, axis, func);
       break;
     }
     default: {
diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
index 053b53041d1..18591e897a7 100644
--- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
+++ b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h
@@ -24,6 +24,12 @@ namespace pten {
 namespace kps = paddle::operators::kernel_primitives;
 
 enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
+/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
+   for supporting multiple-output feature in elementwise system.*/
+template <class T, int Num>
+using OutType =
+    typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;
+
 template <typename InT,
           typename OutT,
           int VecSize,
@@ -79,4 +85,39 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
   }
 };
 
+template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
+struct ElementwiseWriteDataCaller {
+  __device__ __forceinline__ void operator()(
+      paddle::framework::Array<OutT *, NumOuts> outs,
+      OutType<OutT, NumOuts> src[VecSize],
+      int block_offset,
+      int num) {
+    OutT dst[NumOuts][VecSize];
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+#pragma unroll
+      for (int j = 0; j < NumOuts; ++j) {
+        dst[j][i] = (src[i])[j];
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < NumOuts; ++i) {
+      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+          outs[i] + block_offset, dst[i], num);
+    }
+  }
+};
+
+template <typename OutT, int VecSize, bool IsBoundary>
+struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
+  __device__ __forceinline__ void operator()(
+      paddle::framework::Array<OutT *, 1> outs,
+      OutT src[VecSize],
+      int block_offset,
+      int num) {
+    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+        outs[0] + block_offset, src, num);
+  }
+};
+
 }  // namespace pten
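
For reference, a minimal sketch (not part of the patch) of the functor contract this change introduces: when NumOuts > 1, ElementwisePrimitiveCaller is instantiated with OutType<OutT, NumOuts>, so the functor returns a paddle::framework::Array<OutT, NumOuts> carrying one element per output, which ElementwiseWriteDataCaller then transposes and scatters to the NumOuts destination pointers. The functor name FusedAddRelu and the element type T below are hypothetical illustrations, not code from the patch:

    // Hypothetical two-output binary functor: out0 = a + b, out1 = relu(a + b).
    // The only contract the new path imposes is the OutType<T, 2> return type,
    // i.e. paddle::framework::Array<T, 2> packing one element per output.
    template <typename T>
    struct FusedAddRelu {
      __device__ __forceinline__ pten::OutType<T, 2> operator()(const T &a,
                                                                const T &b) const {
        pten::OutType<T, 2> result;
        result[0] = a + b;                         // first output: the sum
        result[1] = result[0] > static_cast<T>(0)  // second output: relu of the sum
                        ? result[0]
                        : static_cast<T>(0);
        return result;
      }
    };

Assuming two same-shaped output tensors collected in a std::vector<DenseTensor *> outs, such a functor would be dispatched with NumOuts = 2, roughly:

    pten::LaunchBroadcastElementwiseCudaKernel<pten::ElementwiseType::kBinary,
                                               T, T, FusedAddRelu<T>, 2>(
        ctx, ins, &outs, axis, FusedAddRelu<T>());

The new PADDLE_ENFORCE_EQ checks then verify that outs->size() equals NumOuts and that every output tensor has identical dims, while the vectorization width is taken as the minimum over all inputs and outputs.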