Unverified Commit 89d38f55 authored by limingshu, committed by GitHub

Support multi-outputs feature for broadcast ops (#38329)

* No harm to KP

* Pass the compile stage

* change the WriteData function

* fix template bugs and pass ctest of current elementwise

* for passing partial template specialization of template function in CI-ROCm

* To make 'WriteData' function flexible.

* a less harmful way to support multi-output

* a less harmful way to support multi-output
Parent f1d56b77
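With this change the broadcast path takes a NumOuts template parameter (default 1) and a vector of output tensors, so one fused kernel launch can produce several same-shape outputs. A hypothetical call-site sketch for a two-output binary op follows; dev_ctx, the tensors, and FMaxFMinFunctor are illustrative assumptions, not code from this commit, and the snippet is meant to be read against the diff below rather than compiled standalone:

// Hypothetical call site for the new multi-output path (NumOuts = 2).
// FMaxFMinFunctor is an assumed functor that returns both results packed
// into one OutType<float, 2> value per element (see the helper added below).
std::vector<const pten::DenseTensor *> ins = {&x, &y};
std::vector<pten::DenseTensor *> outs = {&out_max, &out_min};

pten::LaunchBroadcastElementwiseCudaKernel<pten::ElementwiseType::kBinary,
                                           float,            // InT
                                           float,            // OutT
                                           FMaxFMinFunctor,  // Functor
                                           2>(               // NumOuts
    dev_ctx, ins, &outs, /*axis=*/-1, FMaxFMinFunctor());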
@@ -254,8 +254,8 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
       }
     }
   } else {  // blockDim.x * NX < num
-    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-    const int kVectorsPerThread = NX / kVectorSize;
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
     int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType<T, kVectorSize>;
@@ -441,8 +441,8 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
     }
   } else {
     // Vector type
-    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-    const int kVectorsPerThread = NX / kVectorSize;
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
     int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType<T, kVectorSize>;
...
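The two hunks above switch const to constexpr, making the compile-time nature of kVectorSize and kVectorsPerThread explicit where they parameterize details::VectorType (the commit notes suggest this was also needed to get a template specialization through CI-ROCm). A self-contained sketch of the same selection rule, with VectorType as a stand-in for the Paddle template:

#include <cstdio>

// Stand-in for details::VectorType: a bundle of VecSize elements that is
// loaded/stored as one unit.
template <typename T, int VecSize>
struct VectorType {
  T val[VecSize];
};

template <typename T, int NX>
void ShowVectorization() {
  // Same rule as the diff: prefer 4-wide, then 2-wide, else scalar.
  constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
  constexpr int kVectorsPerThread = NX / kVectorSize;
  using VecType = VectorType<T, kVectorSize>;  // needs a constant expression
  std::printf("NX=%d -> %d vectors of %zu bytes per thread\n",
              NX, kVectorsPerThread, sizeof(VecType));
}

int main() {
  ShowVectorization<float, 8>();  // 2 vectors of 16 bytes (float4-like)
  ShowVectorization<float, 6>();  // 3 vectors of 8 bytes  (float2-like)
  ShowVectorization<float, 3>();  // 3 scalar elements
}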
@@ -193,12 +193,13 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank,
           bool IsBoundary = false>
 __device__ void ElementwiseBroadcastKernelImpl(
     const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     const paddle::framework::Array<bool, Arity> &use_broadcast,
     uint32_t numel,
     const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -207,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
     int block_offset,
     Functor func) {
   InT args[Arity][VecSize];
-  OutT result[VecSize];
+  OutType<OutT, NumOuts> result[VecSize];
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
@@ -220,28 +221,29 @@ __device__ void ElementwiseBroadcastKernelImpl(
         num,
         use_broadcast[i]);
   }
-  const bool kCallElementwiseAny =
+  constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
   ElementwisePrimitiveCaller<InT,
-                             OutT,
+                             OutType<OutT, NumOuts>,
                              VecSize,
                              Functor,
                              Arity,
                              kCallElementwiseAny>()(func, args, result);
-  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
-      out + block_offset, result, num);
+
+  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
+      outs, result, block_offset, num);
 }
 
 template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank>
 __global__ void ElementwiseBroadcastKernel(
     paddle::framework::Array<const InT *__restrict__, Arity> ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     paddle::framework::Array<bool, Arity> use_broadcast,
     uint32_t numel,
     paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -251,16 +253,18 @@ __global__ void ElementwiseBroadcastKernel(
     Functor func) {
   int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
   int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
 #ifdef PADDLE_WITH_XPU2
   for (; block_offset < main_offset; block_offset += stride) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    false>(ins,
-                                          out,
+                                          outs,
                                           use_broadcast,
                                           numel,
                                           configs,
@@ -273,22 +277,23 @@ __global__ void ElementwiseBroadcastKernel(
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
 #else
   if (block_offset < main_offset) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    false>(ins,
-                                          out,
+                                          outs,
                                           use_broadcast,
                                           numel,
                                           configs,
@@ -300,10 +305,11 @@ __global__ void ElementwiseBroadcastKernel(
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
 #endif
 }
@@ -312,25 +318,30 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank>
 void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                   const std::vector<const DenseTensor *> &ins,
-                  DenseTensor *out,
+                  std::vector<DenseTensor *> *outs,
                   Functor func,
                   DimensionsTransform merge_dims) {
-  int numel = out->numel();
+  int numel = (*outs)[0]->numel();
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
   int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
   int tail_tid = numel % (VecSize * threads);
   auto stream = ctx.stream();
-  OutT *out_data = out->mutable_data<OutT>();
   paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
   paddle::framework::Array<bool, Arity> use_broadcast;
   paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
+  paddle::framework::Array<OutT *, NumOuts> outs_data;
+
+  for (int i = 0; i < NumOuts; ++i) {
+    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
+  }
 
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
@@ -343,6 +354,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
         merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
   }
+
 #ifdef PADDLE_WITH_XPU2
   threads = 128;
   blocks = 8;
@@ -352,9 +364,10 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                                  OutT,
                                  Functor,
                                  Arity,
+                                 NumOuts,
                                  VecSize,
                                  Rank><<<blocks, threads, stream>>>(ins_data,
-                                                                    out_data,
+                                                                    outs_data,
                                                                     use_broadcast,
                                                                     numel,
                                                                     configs,
@@ -366,10 +379,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                                  OutT,
                                  Functor,
                                  Arity,
+                                 NumOuts,
                                  VecSize,
                                  Rank><<<blocks, threads, 0, stream>>>(
       ins_data,
-      out_data,
+      outs_data,
       use_broadcast,
       numel,
       configs,
@@ -379,19 +393,24 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
 #endif
 }
 
-template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
+template <typename InT,
+          typename OutT,
+          typename Functor,
+          int Arity,
+          int NumOuts,
+          int VecSize>
 void LaunchBroadcastKernelForDifferentVecSize(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
-    DenseTensor *out,
+    std::vector<DenseTensor *> *outs,
     int axis,
     Functor func) {
-  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
+  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
 
 #define CALL_BROADCAST_FOR_DIM_SIZE(rank)                            \
   case rank: {                                                       \
-    LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>(          \
-        ctx, ins, out, func, merge_dims);                            \
+    LaunchKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
+        ctx, ins, outs, func, merge_dims);                           \
   } break;
 
   switch (merge_dims.dim_size) {
@@ -414,7 +433,11 @@ void LaunchBroadcastKernelForDifferentVecSize(
 #undef CALL_BROADCAST_FOR_DIM_SIZE
 }
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET,
+          typename InT,
+          typename OutT,
+          typename Functor,
+          int NumOuts = 1>
 void LaunchBroadcastElementwiseCudaKernel(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
@@ -438,32 +461,68 @@ void LaunchBroadcastElementwiseCudaKernel(
           "Currently only broadcast of binary is supported and "
           "verified, but received %d.",
           kArity));
+  PADDLE_ENFORCE_EQ(
+      outs->size(),
+      NumOuts,
+      paddle::platform::errors::InvalidArgument(
+          "Number of outputs shall equal to number of functions, "
+          "but number of outputs is %d, number of functions is %d.",
+          outs->size(),
+          NumOuts));
   int in_vec_size = 4;
-  DenseTensor *out = (*outs)[0];
+  int out_vec_size = 4;
+  if (NumOuts > 1) {
+    for (int i = 0; i < NumOuts; ++i) {
+      PADDLE_ENFORCE_EQ(
+          (*outs)[i]->dims(),
+          (*outs)[0]->dims(),
+          paddle::platform::errors::InvalidArgument(
+              "The shape of each output tensor shall be identical yet, but "
+              "%dth output tensor`s shape is not.",
+              i));
+      out_vec_size = std::min(
+          paddle::platform::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()),
+          out_vec_size);
+    }
+  } else {
+    out_vec_size =
+        paddle::platform::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
+  }
+
   for (auto *in : ins) {
     auto temp_size = paddle::platform::GetVectorizedSize<InT>(in->data<InT>());
-    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
-                                            : in_vec_size;
+    in_vec_size = in->dims() == (*outs)[0]->dims()
                      ? std::min(temp_size, in_vec_size)
                      : in_vec_size;
   }
-  int out_vec_size =
-      paddle::platform::GetVectorizedSize<OutT>(out->data<OutT>());
   int vec_size = std::min(out_vec_size, in_vec_size);
 
   switch (vec_size) {
     case 4: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               4>(ctx, ins, outs, axis, func);
       break;
     }
     case 2: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
                                               OutT,
                                               Functor,
                                               kArity,
                                               NumOuts,
                                               2>(ctx, ins, outs, axis, func);
       break;
    }
    case 1: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
                                               OutT,
                                               Functor,
                                               kArity,
                                               NumOuts,
                                               1>(ctx, ins, outs, axis, func);
      break;
    }
    default: {
...
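In the launch wrapper above, the common vector width is now the minimum vectorized size over every input and output, so a single misaligned tensor downgrades the whole launch to a narrower width. A minimal sketch of an alignment probe in the spirit of paddle::platform::GetVectorizedSize (assumed behavior, not the actual implementation):

#include <cstdint>
#include <cstdio>

// Pick the widest vector load a pointer's alignment allows: 4-wide needs the
// address aligned to 4 elements, 2-wide to 2 elements, else fall back to scalar.
template <typename T>
int GetVectorizedSizeSketch(const T *ptr) {
  const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  if (addr % (sizeof(T) * 4) == 0) return 4;
  if (addr % (sizeof(T) * 2) == 0) return 2;
  return 1;
}

int main() {
  alignas(16) float buf[8];
  std::printf("buf     -> %d\n", GetVectorizedSizeSketch(buf));      // 4
  std::printf("buf + 1 -> %d\n", GetVectorizedSizeSketch(buf + 1));  // 1
  std::printf("buf + 2 -> %d\n", GetVectorizedSizeSketch(buf + 2));  // 2
}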
@@ -24,6 +24,12 @@ namespace pten {
 namespace kps = paddle::operators::kernel_primitives;
 enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
+/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
+   for supporting multiple-output feature in elementwise system. */
+template <class T, int Num>
+using OutType =
+    typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;
+
 template <typename InT,
           typename OutT,
           int VecSize,
@@ -76,4 +82,39 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
   }
 };
 
+template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
+struct ElementwiseWriteDataCaller {
+  __device__ __forceinline__ void operator()(
+      paddle::framework::Array<OutT *, NumOuts> outs,
+      OutType<OutT, NumOuts> src[VecSize],
+      int block_offset,
+      int num) {
+    OutT dst[NumOuts][VecSize];
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+#pragma unroll
+      for (int j = 0; j < NumOuts; ++j) {
+        dst[j][i] = (src[i])[j];
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < NumOuts; ++i) {
+      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+          outs[i] + block_offset, dst[i], num);
+    }
+  }
+};
+
+template <typename OutT, int VecSize, bool IsBoundary>
+struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
+  __device__ __forceinline__ void operator()(
+      paddle::framework::Array<OutT *, 1> outs,
+      OutT src[VecSize],
+      int block_offset,
+      int num) {
+    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+        outs[0] + block_offset, src, num);
+  }
+};
+
 }  // namespace pten
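The key trick in this last file is that OutType<T, 1> collapses to plain T, so existing single-output functors keep their scalar signature, while multi-output functors return a small array that ElementwiseWriteDataCaller transposes into one row per output before calling kps::WriteData. A self-contained host-side sketch (Array stands in for paddle::framework::Array; FMaxFMinFunctor is a hypothetical two-output functor):

#include <algorithm>
#include <cstdio>
#include <type_traits>

// Stand-in for paddle::framework::Array.
template <typename T, int N>
struct Array {
  T data[N];
  T &operator[](int i) { return data[i]; }
  const T &operator[](int i) const { return data[i]; }
};

// Same alias as the commit: one output stays a bare scalar, N > 1 packs.
template <class T, int Num>
using OutType = std::conditional_t<Num == 1, T, Array<T, Num>>;

static_assert(std::is_same<OutType<float, 1>, float>::value, "collapses");
static_assert(std::is_same<OutType<float, 2>, Array<float, 2>>::value, "packs");

// Hypothetical two-output functor: element-wise {max, min} of its inputs.
struct FMaxFMinFunctor {
  OutType<float, 2> operator()(float a, float b) const {
    return {{std::max(a, b), std::min(a, b)}};
  }
};

int main() {
  OutType<float, 2> r = FMaxFMinFunctor()(3.f, 7.f);
  // The write-back stage scatters r[0] to outs[0] and r[1] to outs[1].
  std::printf("max=%g min=%g\n", r[0], r[1]);
}

With NumOuts == 1, the ElementwiseWriteDataCaller specialization forwards straight to kps::WriteData, so single-output instantiations compile to exactly what they did before this commit.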