未验证 提交 79ee6d63 编写于 作者: N niuliling123 提交者: GitHub

modified the elementwise_op_broadcast and elementwise_op_impl for xpu2 (#37226)

* modified the elementwise_op_broadcast and elementwise_op_impl for xpu2
上级 128bdf66
...@@ -196,7 +196,7 @@ template <typename InT, ...@@ -196,7 +196,7 @@ template <typename InT,
int VecSize, int VecSize,
int Rank, int Rank,
bool IsBoundary = false> bool IsBoundary = false>
__device__ void DealSegment( __device__ void ElementwiseBroadcastKernelImpl(
const paddle::framework::Array<const InT *__restrict__, Arity> &ins, const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
OutT *out, OutT *out,
const paddle::framework::Array<bool, Arity> &use_broadcast, const paddle::framework::Array<bool, Arity> &use_broadcast,
...@@ -204,12 +204,11 @@ __device__ void DealSegment( ...@@ -204,12 +204,11 @@ __device__ void DealSegment(
const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
&configs, &configs,
int num, int num,
int block_offset,
Functor func) { Functor func) {
InT args[Arity][VecSize]; InT args[Arity][VecSize];
OutT result[VecSize]; OutT result[VecSize];
int block_offset = blockIdx.x * blockDim.x * VecSize;
#pragma unroll #pragma unroll
for (int i = 0; i < Arity; i++) { for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f)); kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
...@@ -240,27 +239,73 @@ template <typename InT, ...@@ -240,27 +239,73 @@ template <typename InT,
int Arity, int Arity,
int VecSize, int VecSize,
int Rank> int Rank>
__global__ void BroadcastKernel( __global__ void ElementwiseBroadcastKernel(
paddle::framework::Array<const InT *__restrict__, Arity> ins, paddle::framework::Array<const InT *__restrict__, Arity> ins,
OutT *out, OutT *out,
paddle::framework::Array<bool, Arity> use_broadcast, paddle::framework::Array<bool, Arity> use_broadcast,
uint32_t numel, uint32_t numel,
paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
configs, configs,
int main_tid, int main_offset,
int tail_tid, int tail_tid,
Functor func) { Functor func) {
int block_offset = blockIdx.x * blockDim.x * VecSize; int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
// data offset of this block int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
if (blockIdx.x < main_tid) { #ifdef PADDLE_WITH_XPU2
int num = blockDim.x * VecSize; // blockIdx.x < main_tid for (; block_offset < main_offset; block_offset += stride) {
pten::DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>( ElementwiseBroadcastKernelImpl<InT,
ins, out, use_broadcast, numel, configs, num, func); OutT,
} else { // reminder Functor,
int num = tail_tid; Arity,
pten::DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>( VecSize,
ins, out, use_broadcast, numel, configs, num, func); Rank,
false>(ins,
out,
use_broadcast,
numel,
configs,
BLOCK_NUM_X * VecSize,
block_offset,
func);
}
if (block_offset < numel) {
ElementwiseBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
VecSize,
Rank,
true>(
ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
} }
#else
if (block_offset < main_offset) {
ElementwiseBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
VecSize,
Rank,
false>(ins,
out,
use_broadcast,
numel,
configs,
BLOCK_NUM_X * VecSize,
block_offset,
func);
} else {
ElementwiseBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
VecSize,
Rank,
true>(
ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
}
#endif
} }
template <typename InT, template <typename InT,
...@@ -278,7 +323,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx, ...@@ -278,7 +323,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
const int threads = 256; const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / (VecSize * threads); int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
int tail_tid = numel % (VecSize * threads); int tail_tid = numel % (VecSize * threads);
auto stream = ctx.stream(); auto stream = ctx.stream();
OutT *out_data = out->mutable_data<OutT>(); OutT *out_data = out->mutable_data<OutT>();
...@@ -298,20 +343,40 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx, ...@@ -298,20 +343,40 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
} }
} }
#ifdef PADDLE_WITH_XPU2
BroadcastKernel<InT, threads = 128;
OutT, blocks = 8;
Functor, main_offset = (numel / (VecSize * threads)) * VecSize * threads;
Arity, tail_tid = numel % (VecSize * threads);
VecSize, ElementwiseBroadcastKernel<InT,
Rank><<<blocks, threads, 0, stream>>>(ins_data, OutT,
out_data, Functor,
use_broadcast, Arity,
numel, VecSize,
configs, Rank><<<blocks, threads, stream>>>(ins_data,
main_tid, out_data,
tail_tid, use_broadcast,
func); numel,
configs,
main_offset,
tail_tid,
func);
#else
ElementwiseBroadcastKernel<InT,
OutT,
Functor,
Arity,
VecSize,
Rank><<<blocks, threads, 0, stream>>>(
ins_data,
out_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#endif
} }
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize> template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
......
...@@ -57,16 +57,15 @@ template <typename InT, ...@@ -57,16 +57,15 @@ template <typename InT,
int Arity, int Arity,
int VecSize, int VecSize,
bool IsBoundary> bool IsBoundary>
__device__ void DealSegment( __device__ void VectorizedElementwiseKernelImpl(
const paddle::framework::Array<const InT *__restrict__, Arity> &in, const paddle::framework::Array<const InT *__restrict__, Arity> &in,
OutT *out, OutT *out,
int num, int num,
int data_offset,
Functor func) { Functor func) {
InT args[Arity][VecSize]; InT args[Arity][VecSize];
OutT result[VecSize]; OutT result[VecSize];
int data_offset = VecSize * blockIdx.x * blockDim.x;
#pragma unroll #pragma unroll
for (int i = 0; i < Arity; i++) { for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f)); kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
...@@ -87,18 +86,23 @@ __device__ void DealSegment( ...@@ -87,18 +86,23 @@ __device__ void DealSegment(
} }
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize> template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
__global__ void ElementVectorizeKernel( __global__ void VectorizedElementwiseKernel(
paddle::framework::Array<const InT *__restrict__, Arity> ins, paddle::framework::Array<const InT *__restrict__, Arity> ins,
OutT *out, OutT *out,
int size, int size,
int main_offset,
Functor func) { Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x; int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
for (; data_offset < main_offset; data_offset += stride) {
VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, false>(
ins, out, VecSize * BLOCK_NUM_X, data_offset, func);
}
int num = size - data_offset; int num = size - data_offset;
// the num this time have to deal with if (num > 0) {
if (VecSize * blockDim.x > num) { // reminder segment VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, true>(
DealSegment<InT, OutT, Functor, Arity, VecSize, true>(ins, out, num, func); ins, out, num, data_offset, func);
} else { // complete segment
DealSegment<InT, OutT, Functor, Arity, VecSize, false>(ins, out, num, func);
} }
} }
...@@ -132,12 +136,25 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, ...@@ -132,12 +136,25 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
for (int i = 0; i < Arity; i++) { for (int i = 0; i < Arity; i++) {
ins_data[i] = ins[i]->data<InT>(); ins_data[i] = ins[i]->data<InT>();
} }
ElementVectorizeKernel<InT, #ifdef PADDLE_WITH_XPU2
OutT, block_size = 128;
Functor, grid_size = 8;
Arity, int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
VecSize><<<grid_size, block_size, 0, stream>>>( VectorizedElementwiseKernel<InT,
ins_data, out_data, numel, func); OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, main_offset, func);
#else
int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
VectorizedElementwiseKernel<InT,
OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, main_offset, func);
#endif
} }
template <ElementwiseType ET, typename InT, typename OutT, typename Functor> template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册