未验证 提交 79ee6d63 编写于 作者: N niuliling123 提交者: GitHub

modified the elementwise_op_broadcast and elementwise_op_impl for xpu2 (#37226)

* modified the elementwise_op_broadcast and elementwise_op_impl for xpu2
上级 128bdf66
......@@ -196,7 +196,7 @@ template <typename InT,
int VecSize,
int Rank,
bool IsBoundary = false>
__device__ void DealSegment(
__device__ void ElementwiseBroadcastKernelImpl(
const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
OutT *out,
const paddle::framework::Array<bool, Arity> &use_broadcast,
......@@ -204,12 +204,11 @@ __device__ void DealSegment(
const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
&configs,
int num,
int block_offset,
Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
int block_offset = blockIdx.x * blockDim.x * VecSize;
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
// Broadcast elementwise kernel over `numel` flattened output elements.
// `main_offset` is the end of the full VecSize-wide segments (caller computes
// it as (numel / (VecSize * threads)) * VecSize * threads); those segments are
// processed without bounds checks (IsBoundary = false). The remaining
// `tail_tid` elements take the boundary-checked path (IsBoundary = true).
// On XPU2 a fixed-size grid strides over all segments; on GPU the grid is
// sized so each block handles at most one segment, so no stride loop is used.
// NOTE(review): BLOCK_ID_X / BLOCK_NUM_X / GRID_NUM_X are kps macros that
// abstract blockIdx.x / blockDim.x / gridDim.x across CUDA and XPU2.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int VecSize,
          int Rank>
__global__ void ElementwiseBroadcastKernel(
    paddle::framework::Array<const InT *__restrict__, Arity> ins,
    OutT *out,
    paddle::framework::Array<bool, Arity> use_broadcast,
    uint32_t numel,
    paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
        configs,
    int main_offset,
    int tail_tid,
    Functor func) {
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
#ifdef PADDLE_WITH_XPU2
  // Grid-stride loop over the full segments, then one boundary-checked call
  // for the tail (`tail_tid` elements starting at `main_offset`).
  for (; block_offset < main_offset; block_offset += stride) {
    ElementwiseBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
                                   VecSize,
                                   Rank,
                                   false>(ins,
                                          out,
                                          use_broadcast,
                                          numel,
                                          configs,
                                          BLOCK_NUM_X * VecSize,
                                          block_offset,
                                          func);
  }
  if (block_offset < numel) {
    ElementwiseBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
                                   VecSize,
                                   Rank,
                                   true>(
        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
  }
#else
  // One segment per block: full segments below main_offset, remainder above.
  if (block_offset < main_offset) {
    ElementwiseBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
                                   VecSize,
                                   Rank,
                                   false>(ins,
                                          out,
                                          use_broadcast,
                                          numel,
                                          configs,
                                          BLOCK_NUM_X * VecSize,
                                          block_offset,
                                          func);
  } else {
    ElementwiseBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
                                   VecSize,
                                   Rank,
                                   true>(
        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
  }
#endif
}
template <typename InT,
......@@ -278,7 +323,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / (VecSize * threads);
int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
int tail_tid = numel % (VecSize * threads);
auto stream = ctx.stream();
OutT *out_data = out->mutable_data<OutT>();
......@@ -298,20 +343,40 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}
BroadcastKernel<InT,
OutT,
Functor,
Arity,
VecSize,
Rank><<<blocks, threads, 0, stream>>>(ins_data,
out_data,
use_broadcast,
numel,
configs,
main_tid,
tail_tid,
func);
#ifdef PADDLE_WITH_XPU2
threads = 128;
blocks = 8;
main_offset = (numel / (VecSize * threads)) * VecSize * threads;
tail_tid = numel % (VecSize * threads);
ElementwiseBroadcastKernel<InT,
OutT,
Functor,
Arity,
VecSize,
Rank><<<blocks, threads, stream>>>(ins_data,
out_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#else
ElementwiseBroadcastKernel<InT,
OutT,
Functor,
Arity,
VecSize,
Rank><<<blocks, threads, 0, stream>>>(
ins_data,
out_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#endif
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
......
......@@ -57,16 +57,15 @@ template <typename InT,
int Arity,
int VecSize,
bool IsBoundary>
__device__ void DealSegment(
__device__ void VectorizedElementwiseKernelImpl(
const paddle::framework::Array<const InT *__restrict__, Arity> &in,
OutT *out,
int num,
int data_offset,
Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
int data_offset = VecSize * blockIdx.x * blockDim.x;
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
......@@ -87,18 +86,23 @@ __device__ void DealSegment(
}
// Vectorized elementwise kernel over `size` flattened elements.
// `main_offset` is the boundary of the full VecSize-wide segments (caller
// computes it as (size / (VecSize * block size)) * VecSize * block size);
// the grid-stride loop handles those without bounds checks, and the
// remaining `size - main_offset` tail elements take the boundary-checked
// (IsBoundary = true) path.
// NOTE(review): BLOCK_ID_X / BLOCK_NUM_X / GRID_NUM_X are kps macros that
// abstract blockIdx.x / blockDim.x / gridDim.x across CUDA and XPU2.
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
__global__ void VectorizedElementwiseKernel(
    paddle::framework::Array<const InT *__restrict__, Arity> ins,
    OutT *out,
    int size,
    int main_offset,
    Functor func) {
  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
  // Full segments: each iteration processes VecSize * BLOCK_NUM_X elements.
  for (; data_offset < main_offset; data_offset += stride) {
    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, false>(
        ins, out, VecSize * BLOCK_NUM_X, data_offset, func);
  }
  // Number of tail elements still to process in this launch, if any.
  int num = size - data_offset;
  if (num > 0) {
    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, true>(
        ins, out, num, data_offset, func);
  }
}
......@@ -132,12 +136,25 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
for (int i = 0; i < Arity; i++) {
ins_data[i] = ins[i]->data<InT>();
}
ElementVectorizeKernel<InT,
OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, func);
#ifdef PADDLE_WITH_XPU2
block_size = 128;
grid_size = 8;
int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
VectorizedElementwiseKernel<InT,
OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, main_offset, func);
#else
int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
VectorizedElementwiseKernel<InT,
OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, main_offset, func);
#endif
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册