Unverified — Commit 89d38f55 authored by limingshu, committed by GitHub

Support multi-outputs feature for broadcast ops (#38329)

* No harm to KP

* Pass the compile stage

* change the WriteData function

* fix template bugs and pass ctest of current elementwise

* for passing partial template specialization of template function in CI-ROCm

* To make 'WriteData' function flexible.

* a less harmful way to support multi-output

* a less harmful way to support multi-output
Parent f1d56b77
@@ -254,8 +254,8 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
       }
     }
   } else {  // blockDim.x * NX < num
-    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-    const int kVectorsPerThread = NX / kVectorSize;
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
     int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType<T, kVectorSize>;
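The switch from `const int` to `constexpr int` makes the vector width an explicit compile-time constant that parameterizes the `VecType` alias. A minimal host-side sketch of the pattern (the `VectorType` below is a stand-in, not Paddle's `kps::details::VectorType`):

```cpp
#include <cstdio>

// Stand-in for kps::details::VectorType: a trivially copyable aggregate the
// compiler can move with one wide load/store.
template <typename T, int Size>
struct VectorType {
  T val[Size];
};

template <typename T, int NX>
void ReadDataSketch(T* dst, const T* src) {
  // Same compile-time width selection as in the hunk above.
  constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
  constexpr int kVectorsPerThread = NX / kVectorSize;
  using VecType = VectorType<T, kVectorSize>;
  const VecType* vec_src = reinterpret_cast<const VecType*>(src);
  VecType* vec_dst = reinterpret_cast<VecType*>(dst);
  for (int i = 0; i < kVectorsPerThread; ++i) {
    vec_dst[i] = vec_src[i];  // one wide copy instead of kVectorSize scalar copies
  }
}

int main() {
  float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float dst[8] = {};
  ReadDataSketch<float, 8>(dst, src);
  printf("%g %g\n", dst[0], dst[7]);  // prints: 0 7
}
```

Copying through the wider aggregate lets the compiler emit a single 128-bit transaction per `kVectorSize` elements instead of `kVectorSize` scalar ones.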
@@ -441,8 +441,8 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
     }
   } else {
     // Vector type
-    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-    const int kVectorsPerThread = NX / kVectorSize;
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
     int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType<T, kVectorSize>;
...
@@ -193,12 +193,13 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank,
           bool IsBoundary = false>
 __device__ void ElementwiseBroadcastKernelImpl(
     const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     const paddle::framework::Array<bool, Arity> &use_broadcast,
     uint32_t numel,
     const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -207,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
     int block_offset,
     Functor func) {
   InT args[Arity][VecSize];
-  OutT result[VecSize];
+  OutType<OutT, NumOuts> result[VecSize];

 #pragma unroll
   for (int i = 0; i < Arity; i++) {
@@ -220,28 +221,29 @@ __device__ void ElementwiseBroadcastKernelImpl(
         num,
         use_broadcast[i]);
   }
-  const bool kCallElementwiseAny =
+  constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
   ElementwisePrimitiveCaller<InT,
-                             OutT,
+                             OutType<OutT, NumOuts>,
                              VecSize,
                              Functor,
                              Arity,
                              kCallElementwiseAny>()(func, args, result);
-  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
-      out + block_offset, result, num);
+  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
+      outs, result, block_offset, num);
 }

 template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank>
 __global__ void ElementwiseBroadcastKernel(
     paddle::framework::Array<const InT *__restrict__, Arity> ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     paddle::framework::Array<bool, Arity> use_broadcast,
     uint32_t numel,
     paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
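With `NumOuts > 1`, `OutType<OutT, NumOuts>` resolves to `paddle::framework::Array<OutT, NumOuts>`, so one functor call produces every output value for an element at once. A hypothetical two-output functor under that contract (`DivModFunctor` is illustrative only, not part of this patch; the local `Array` is a standalone stand-in mirroring `paddle::framework::Array`):

```cpp
#include <cstdio>

// Local stand-in mirroring the shape of paddle::framework::Array.
template <typename T, int N>
struct Array {
  T data[N];
  T& operator[](int i) { return data[i]; }
};

// Hypothetical functor for NumOuts == 2: returns both results for one input
// pair; ElementwiseWriteDataCaller then scatters them to the two outputs.
template <typename T>
struct DivModFunctor {
  Array<T, 2> operator()(T a, T b) const {
    Array<T, 2> result;
    result[0] = a / b;  // written to outs[0]
    result[1] = a % b;  // written to outs[1]
    return result;
  }
};

int main() {
  DivModFunctor<int> divmod;
  auto r = divmod(7, 3);
  printf("%d %d\n", r[0], r[1]);  // prints: 2 1
}
```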
@@ -251,16 +253,18 @@ __global__ void ElementwiseBroadcastKernel(
     Functor func) {
   int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
   int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
 #ifdef PADDLE_WITH_XPU2
   for (; block_offset < main_offset; block_offset += stride) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    false>(ins,
-                                          out,
+                                          outs,
                                           use_broadcast,
                                           numel,
                                           configs,
@@ -273,22 +277,23 @@ __global__ void ElementwiseBroadcastKernel(
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
 #else
   if (block_offset < main_offset) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    false>(ins,
-                                          out,
+                                          outs,
                                           use_broadcast,
                                           numel,
                                           configs,
@@ -300,10 +305,11 @@ __global__ void ElementwiseBroadcastKernel(
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
 #endif
 }
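The two preprocessor branches differ in launch shape: the XPU2 path keeps a small fixed grid and sweeps the range with a stride loop, while the CUDA path launches enough blocks to cover `numel` in one pass, so each block runs the body once plus at most one boundary call. A stripped-down CUDA sketch of the two shapes (illustrative, not the kernel above):

```cpp
// Persistent-grid style (XPU2 branch): a few blocks sweep the whole range.
__global__ void StrideLoopSketch(const float* in, float* out, int n) {
  int stride = gridDim.x * blockDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
    out[i] = 2.f * in[i];
  }
}

// One-pass style (CUDA branch): the grid is sized so every element is
// touched exactly once; each thread needs only a single bounds check.
__global__ void SinglePassSketch(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 2.f * in[i];
  }
}
```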
@@ -312,25 +318,30 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank>
 void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                   const std::vector<const DenseTensor *> &ins,
-                  DenseTensor *out,
+                  std::vector<DenseTensor *> *outs,
                   Functor func,
                   DimensionsTransform merge_dims) {
-  int numel = out->numel();
+  int numel = (*outs)[0]->numel();
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
   int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
   int tail_tid = numel % (VecSize * threads);
   auto stream = ctx.stream();
-  OutT *out_data = out->mutable_data<OutT>();
   paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
   paddle::framework::Array<bool, Arity> use_broadcast;
   paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
+  paddle::framework::Array<OutT *, NumOuts> outs_data;
+  for (int i = 0; i < NumOuts; ++i) {
+    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
+  }
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
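For concreteness, a worked instance of the launch arithmetic above (host-only sketch): with `numel = 3000`, `VecSize = 4`, and 256 threads, one block-pass covers 1024 elements, so 2048 elements take the fast non-boundary path and the remaining 952 take the boundary path.

```cpp
#include <cstdio>

int main() {
  const int numel = 3000, threads = 256, vec_size = 4;
  // Same formulas as in LaunchKernel above.
  int blocks = ((numel + vec_size - 1) / vec_size + threads - 1) / threads;
  int main_offset = (numel / (vec_size * threads)) * vec_size * threads;
  int tail_tid = numel % (vec_size * threads);
  // prints: blocks=3 main_offset=2048 tail_tid=952
  printf("blocks=%d main_offset=%d tail_tid=%d\n", blocks, main_offset,
         tail_tid);
}
```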
@@ -343,6 +354,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
           merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
   }
 #ifdef PADDLE_WITH_XPU2
   threads = 128;
   blocks = 8;
@@ -352,9 +364,10 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                               OutT,
                               Functor,
                               Arity,
+                              NumOuts,
                               VecSize,
                               Rank><<<blocks, threads, stream>>>(ins_data,
-                                                                 out_data,
+                                                                 outs_data,
                                                                  use_broadcast,
                                                                  numel,
                                                                  configs,
@@ -366,10 +379,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                               OutT,
                               Functor,
                               Arity,
+                              NumOuts,
                               VecSize,
                               Rank><<<blocks, threads, 0, stream>>>(
         ins_data,
-        out_data,
+        outs_data,
         use_broadcast,
         numel,
         configs,
@@ -379,19 +393,24 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
 #endif
 }

-template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
+template <typename InT,
+          typename OutT,
+          typename Functor,
+          int Arity,
+          int NumOuts,
+          int VecSize>
 void LaunchBroadcastKernelForDifferentVecSize(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
-    DenseTensor *out,
+    std::vector<DenseTensor *> *outs,
     int axis,
     Functor func) {
-  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
+  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);

-#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                   \
-  case rank: {                                              \
-    LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>( \
-        ctx, ins, out, func, merge_dims);                   \
+#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                            \
+  case rank: {                                                       \
+    LaunchKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
+        ctx, ins, outs, func, merge_dims);                           \
   } break;

   switch (merge_dims.dim_size) {
@@ -414,7 +433,11 @@ void LaunchBroadcastKernelForDifferentVecSize(
 #undef CALL_BROADCAST_FOR_DIM_SIZE
 }
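`CALL_BROADCAST_FOR_DIM_SIZE` turns the runtime `merge_dims.dim_size` into the compile-time `Rank` template argument. A self-contained sketch of that switch-plus-macro dispatch pattern (`Process` stands in for `LaunchKernel`):

```cpp
#include <cstdio>

template <int Rank>
void Process() {
  printf("instantiated for rank %d\n", Rank);
}

void DispatchRank(int dim_size) {
// Each case instantiates the template for one fixed rank.
#define CALL_FOR_RANK(rank) \
  case rank: {              \
    Process<rank>();        \
  } break;

  switch (dim_size) {
    CALL_FOR_RANK(1)
    CALL_FOR_RANK(2)
    CALL_FOR_RANK(3)
    default:
      printf("unsupported rank %d\n", dim_size);
  }
#undef CALL_FOR_RANK
}

int main() { DispatchRank(2); }  // prints: instantiated for rank 2
```

The cost of this pattern is one template instantiation per supported rank; in exchange, loop bounds inside the kernel become compile-time constants.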
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET,
+          typename InT,
+          typename OutT,
+          typename Functor,
+          int NumOuts = 1>
 void LaunchBroadcastElementwiseCudaKernel(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
@@ -438,32 +461,68 @@ void LaunchBroadcastElementwiseCudaKernel(
           "Currently only broadcast of binary is supported and "
           "verified, but received %d.",
           kArity));
+  PADDLE_ENFORCE_EQ(
+      outs->size(),
+      NumOuts,
+      paddle::platform::errors::InvalidArgument(
+          "Number of outputs shall equal the number of functions, "
+          "but number of outputs is %d, number of functions is %d.",
+          outs->size(),
+          NumOuts));

   int in_vec_size = 4;
-  DenseTensor *out = (*outs)[0];
+  int out_vec_size = 4;
+  if (NumOuts > 1) {
+    for (int i = 0; i < NumOuts; ++i) {
+      PADDLE_ENFORCE_EQ(
+          (*outs)[i]->dims(),
+          (*outs)[0]->dims(),
+          paddle::platform::errors::InvalidArgument(
+              "The shape of each output tensor shall be identical, but "
+              "the %dth output tensor's shape is not.",
+              i));
+      out_vec_size = std::min(
+          paddle::platform::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()),
+          out_vec_size);
+    }
+  } else {
+    out_vec_size =
+        paddle::platform::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
+  }
+
   for (auto *in : ins) {
     auto temp_size = paddle::platform::GetVectorizedSize<InT>(in->data<InT>());
-    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
-                                            : in_vec_size;
+    in_vec_size = in->dims() == (*outs)[0]->dims()
+                      ? std::min(temp_size, in_vec_size)
+                      : in_vec_size;
   }
-  int out_vec_size =
-      paddle::platform::GetVectorizedSize<OutT>(out->data<OutT>());
   int vec_size = std::min(out_vec_size, in_vec_size);

   switch (vec_size) {
     case 4: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               4>(ctx, ins, outs, axis, func);
       break;
     }
     case 2: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               2>(ctx, ins, outs, axis, func);
       break;
     }
     case 1: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               1>(ctx, ins, outs, axis, func);
       break;
     }
     default: {
...
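Taking the minimum of `GetVectorizedSize` over every output matters because a vectorized store is only legal when the pointer is aligned to the full vector width. A sketch of an alignment-based width check (an assumption about what such a helper does, not Paddle's actual implementation):

```cpp
#include <cstdint>
#include <cstdio>

template <typename T>
int VectorizedSizeSketch(const T* ptr) {
  uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
  if (address % (sizeof(T) * 4) == 0) return 4;  // e.g. float4-capable
  if (address % (sizeof(T) * 2) == 0) return 2;  // e.g. float2-capable
  return 1;                                      // scalar fallback
}

int main() {
  alignas(16) float buf[8];
  // prints: 4 1 -- one misaligned pointer forces the whole launch to VecSize 1
  printf("%d %d\n", VectorizedSizeSketch(buf), VectorizedSizeSketch(buf + 1));
}
```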
@@ -24,6 +24,12 @@ namespace pten {
 namespace kps = paddle::operators::kernel_primitives;

 enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };

+/* Packs scalar type T (float, int, etc.) into an Array<T, NumOuts> type
+   to support the multiple-output feature in the elementwise system. */
+template <class T, int Num>
+using OutType =
+    typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;
+
 template <typename InT,
           typename OutT,
           int VecSize,
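The `OutType` alias degenerates to the plain scalar for a single output, so existing single-output functors keep compiling unchanged. A standalone check of what it resolves to (the local `Array` mirrors `paddle::framework::Array` for illustration):

```cpp
#include <type_traits>

template <class T, int N>
struct Array { T data[N]; };

template <class T, int Num>
using OutType = std::conditional_t<Num == 1, T, Array<T, Num>>;

static_assert(std::is_same<OutType<float, 1>, float>::value,
              "one output: the scalar type itself");
static_assert(std::is_same<OutType<float, 2>, Array<float, 2>>::value,
              "several outputs: packed into an Array");

int main() { return 0; }
```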
@@ -76,4 +82,39 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
   }
 };

+template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
+struct ElementwiseWriteDataCaller {
+  __device__ __forceinline__ void operator()(
+      paddle::framework::Array<OutT *, NumOuts> outs,
+      OutType<OutT, NumOuts> src[VecSize],
+      int block_offset,
+      int num) {
+    OutT dst[NumOuts][VecSize];
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+#pragma unroll
+      for (int j = 0; j < NumOuts; ++j) {
+        dst[j][i] = (src[i])[j];
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < NumOuts; ++i) {
+      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+          outs[i] + block_offset, dst[i], num);
+    }
+  }
+};
+
+template <typename OutT, int VecSize, bool IsBoundary>
+struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
+  __device__ __forceinline__ void operator()(
+      paddle::framework::Array<OutT *, 1> outs,
+      OutT src[VecSize],
+      int block_offset,
+      int num) {
+    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+        outs[0] + block_offset, src, num);
+  }
+};
+
} // namespace pten
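The generic `ElementwiseWriteDataCaller` first transposes `src` (one packed `Array` per element, array-of-structs) into `dst` (one contiguous row per output, struct-of-arrays), so each output can then be flushed with a single vectorized `WriteData` call. A host-side sketch of that transpose:

```cpp
#include <cstdio>

template <typename T, int N>
struct Array {
  T data[N];
  T& operator[](int i) { return data[i]; }
};

int main() {
  constexpr int kVecSize = 4, kNumOuts = 2;
  Array<float, kNumOuts> src[kVecSize];  // per-element packed results
  for (int i = 0; i < kVecSize; ++i) {
    src[i][0] = static_cast<float>(i);   // output 0, element i
    src[i][1] = 10.f * i;                // output 1, element i
  }
  float dst[kNumOuts][kVecSize];         // per-output contiguous rows
  for (int i = 0; i < kVecSize; ++i) {
    for (int j = 0; j < kNumOuts; ++j) {
      dst[j][i] = src[i][j];
    }
  }
  printf("%g %g\n", dst[0][3], dst[1][3]);  // prints: 3 30
}
```

The 1-output specialization skips the transpose entirely and forwards `src` straight to `WriteData`, which is what keeps the single-output fast path identical to the pre-patch code.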