Unverified · Commit eae4bf5b authored by niuliling123, committed by GitHub

Modify the elementwise op according to the kernel primitive API (#34456)

Parent b211f02b
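The diff below replaces the hand-rolled vectorized/scalar load-store helpers with calls into the kernel primitive API (kps::Init, kps::ReadData, kps::ReadDataBc, kps::ElementwiseUnary/Binary/Ternary, kps::WriteData). As a rough orientation only, the following sketch (not part of this commit) shows the same load -> compute -> store staging with plain loops. LoadTile, StoreTile, VEC_SIZE, and BinaryKernel are hypothetical stand-ins for the kps read/write primitives, the VecSize template parameter, and the kernels in the diff; the boundary flag mirrors the IsBoundary switch used for the tail segment.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int VEC_SIZE = 4;

template <typename T, bool IsBoundary>
__device__ void LoadTile(T dst[VEC_SIZE], const T *src, int num) {
  // Copy VEC_SIZE consecutive elements into registers; the boundary variant
  // guards against reading past the end of this block's segment.
  for (int i = 0; i < VEC_SIZE; ++i) {
    int idx = threadIdx.x * VEC_SIZE + i;
    if (!IsBoundary || idx < num) dst[i] = src[idx];
  }
}

template <typename T, bool IsBoundary>
__device__ void StoreTile(T *dst, const T src[VEC_SIZE], int num) {
  for (int i = 0; i < VEC_SIZE; ++i) {
    int idx = threadIdx.x * VEC_SIZE + i;
    if (!IsBoundary || idx < num) dst[idx] = src[i];
  }
}

struct Add {
  __device__ float operator()(float a, float b) const { return a + b; }
};

template <typename T, typename Functor, bool IsBoundary>
__global__ void BinaryKernel(const T *x, const T *y, T *z, int n, Functor op) {
  int block_offset = blockIdx.x * blockDim.x * VEC_SIZE;
  int num = n - block_offset;  // elements remaining for this block
  T a[VEC_SIZE], b[VEC_SIZE], r[VEC_SIZE];
  // load
  LoadTile<T, IsBoundary>(a, x + block_offset, num);
  LoadTile<T, IsBoundary>(b, y + block_offset, num);
  // compute
  for (int i = 0; i < VEC_SIZE; ++i) r[i] = op(a[i], b[i]);
  // store
  StoreTile<T, IsBoundary>(z + block_offset, r, num);
}

int main() {
  const int n = 1000, threads = 128;
  const int per_block = threads * VEC_SIZE;
  const int blocks = (n + per_block - 1) / per_block;
  float *x, *y, *z;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  cudaMallocManaged(&z, n * sizeof(float));
  for (int i = 0; i < n; ++i) { x[i] = 1.0f; y[i] = static_cast<float>(i); }
  // Every block takes the boundary-checked path here for simplicity; the
  // kernels in this commit instead pick IsBoundary = true only for the tail.
  BinaryKernel<float, Add, true><<<blocks, threads>>>(x, y, z, n, Add());
  cudaDeviceSynchronize();
  printf("z[%d] = %f\n", n - 1, z[n - 1]);  // expect 1000.0
  cudaFree(x); cudaFree(y); cudaFree(z);
  return 0;
}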
@@ -15,10 +15,14 @@
 #pragma once

 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"

 namespace paddle {
 namespace operators {

+#define MAX_INPUT_NUM 3  // the max num of ET for BroadcastConfig
+
+namespace kps = paddle::operators::kernel_primitives;
+
 struct DimensionsTransform {
   using DimVector = std::vector<int64_t>;
   typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
@@ -161,202 +165,113 @@ struct DimensionsTransform {
   }
 };

-struct StridesCalculation {
-  std::vector<std::vector<uint32_t>> strides;
-  std::vector<platform::FastDivMod> divmoders;
-
- private:
-  // To calculate the strides of each input_tensor.
-  __inline__ void CalculateStrides(
-      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
-    for (int j = 0; j < N; ++j) {
-      for (int i = 0; i < dim_size; ++i) {
-        strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
-        strides[j][i] =
-            (i != 0 && strides[j][i] != 0)
-                ? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i, 1,
-                                  std::multiplies<int64_t>())
-                : strides[j][i];
-      }
-    }
-  }
-
- public:
-  explicit StridesCalculation(const int64_t &dim_size,
-                              const std::vector<std::vector<int64_t>> &in_dims,
-                              const std::vector<int64_t> &out_dims) {
-    const auto N = in_dims.size();
-    divmoders.resize(dim_size);
-    strides.resize(N, std::vector<uint32_t>(dim_size, 1));
-
-    for (int i = 0; i < dim_size; ++i) {
-      divmoders[i] = platform::FastDivMod(out_dims[i]);
-    }
-    CalculateStrides(N, dim_size, in_dims);
-  }
-};
-
-template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
-          int VecSize, int kDims>
-struct BroadcastArgsWrapper {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-
-  OutT *out_data;
-  OutVecType *vec_out_data;
-  const InT *__restrict__ in_data[ET];
-  const InVecType *__restrict__ vec_in_data[ET];
-  bool no_broadcast[ET];
-  platform::FastDivMod divmoders[kDims];
-  uint32_t strides[ET][framework::DDim::kMaxRank];
-  uint32_t scalar_cal_offset;
-  Functor func;
-
-  HOSTDEVICE BroadcastArgsWrapper(
-      const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
-      int scalar_cal_offset, Functor func,
-      const StridesCalculation &offset_calculator)
-      : scalar_cal_offset(scalar_cal_offset), func(func) {
-    for (int j = 0; j < ET; ++j) {
-      in_data[j] = ins[j]->data<InT>();
-      vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
-      no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
-      memcpy(strides[j], offset_calculator.strides[j].data(),
-             kDims * sizeof(uint32_t));
-    }
-    out_data = out->data<OutT>();
-    vec_out_data = reinterpret_cast<OutVecType *>(out_data);
-    memcpy(divmoders, offset_calculator.divmoders.data(),
-           kDims * sizeof(platform::FastDivMod));
-  }
-
-  __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
-    uint32_t offset = 0;
-
-#pragma unroll(kDims)
-    for (int i = 0; i < kDims; ++i) {
-      auto fast_divmoder = divmoders[i].Divmod(idx);
-      idx = fast_divmoder.val[0];
-      offset += fast_divmoder.val[1] * strides[in_idx][i];
-    }
-    return offset;
-  }
-
-  __device__ __forceinline__ void LoadVectorizedDataCommon(
-      InVecType *vector_args, int tid, int idx) {
-    *vector_args = vec_in_data[idx][tid];
-  }
-
-  __device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
-                                                             int tid, int idx) {
-    int index = tid * VecSize;
-#pragma unroll(VecSize)
-    for (int i = 0; i < VecSize; ++i) {
-      uint32_t offset = GetOffsetByDivmod(index + i, idx);
-      scalar_args[i] = in_data[idx][offset];
-    }
-  }
-
-  __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
-                                                           int idx) {
-    args[idx] = in_data[idx][tid + scalar_cal_offset];
-  }
-
-  __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
-                                                             int tid, int idx) {
-    auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
-    args[idx] = in_data[idx][offset];
-  }
-
-  __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
-                                                     int tid) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      if (no_broadcast[j]) {
-        InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
-        LoadVectorizedDataCommon(vector_args, tid, j);
-      } else {
-        LoadVectorizedDataByDivmod(args[j], tid, j);
-      }
-    }
-  }
-
-  __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      if (no_broadcast[j]) {
-        LoadScalarizedDataCommon(args, tid, j);
-      } else {
-        LoadScalarizedDataByDivmod(args, tid, j);
-      }
-    }
-  }
-
-  __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
-                                                      int tid) {
-    vec_out_data[tid] = vec_args_out;
-  }
-
-  __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
-    out_data[scalar_cal_offset + tid] = args_out;
-  }
-};
-
-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET>
-__device__ inline void ScalarizedBroadcastKernelImpl(
-    BroadcastArgsWrapper broadcast_wrapper, int tid) {
-  InT args[ET];
-  OutT args_out;
-  broadcast_wrapper.LoadScalarizedData(args, tid);
-
-  // Calculation of the in_tensor data.
-  args_out = broadcast_wrapper.func(args);
-
-  broadcast_wrapper.StoreScalarizedData(args_out, tid);
-}
-
-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET, int VecSize>
-__device__ inline void VectorizedBroadcastKernelImpl(
-    BroadcastArgsWrapper broadcast_wrapper, int tid) {
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-  OutVecType args_out;
-  InT ins[ET];
-  InT args[ET][VecSize];
-  broadcast_wrapper.LoadVectorizedData(args, tid);
-
-#pragma unroll(VecSize)
-  for (int i = 0; i < VecSize; ++i) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      ins[j] = args[j][i];
-    }
-    args_out.val[i] = broadcast_wrapper.func(ins);
-  }
-  broadcast_wrapper.StoreVectorizedData(args_out, tid);
-}
-
-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET, int VecSize>
-__global__ void ElementwiseBroadcastKernel(
-    BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  // Vectorized calculation of the major data whose length is the max multiple
-  // of VecSize,
-  // eg: calculating the front 1024-length data in total 1027 data once VecSize
-  // is 4.
-  if (tid < main_tid) {
-    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET, VecSize>(
-        broadcast_wrapper, tid);
-  }
-  // Scalarized calculation of the rest data whose length cannot fill VecSize.
-  // eg: calculating the rest 3-length data in total 1027 data once VecSize is
-  // 4.
-  if (tid < tail_tid) {
-    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET>(
-        broadcast_wrapper, tid);
-  }
-}
+template <typename T, int VecSize, int ShapeSize, bool IsBoundary = false>
+__device__ __forceinline__ void LoadData(
+    T *dst, const T *__restrict__ src, uint32_t block_offset,
+    const kps::details::BroadcastConfig<ShapeSize> &config, int numel, int num,
+    bool need_broadcast) {
+  // numel: total number of output elements
+  // num: number of elements handled in this call
+  if (need_broadcast) {
+    kps::ReadDataBc<T, VecSize, 1, 1, ShapeSize, IsBoundary>(
+        dst, src, block_offset, config, numel, 1, 1);
+  } else {
+    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
+  }
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
+          int VecSize, typename Functor, bool IsBoundary = false>
+__device__ void DealSegment(
+    const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
+    const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel,
+    const framework::Array<kps::details::BroadcastConfig<ShapeSize>,
+                           MAX_INPUT_NUM> &configlists,
+    int num, Functor func) {
+  InT args[ET][VecSize];
+  OutT result[VecSize];
+  int block_offset = blockIdx.x * blockDim.x * VecSize;
+  // load
+#pragma unroll
+  for (int i = 0; i < ET; i++) {
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    LoadData<InT, VecSize, ShapeSize, IsBoundary>(args[i], in[i], block_offset,
+                                                  configlists[i], numel, num,
+                                                  use_broadcast[i]);
+  }
+  // compute
+  if (ET == kUnary) {
+    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                             func);
+  } else if (ET == kBinary) {
+    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                              args[1], func);
+  } else {
+    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
+        result, args[0], args[1], args[2], func);
+  }
+  // store
+  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
+                                                  num);
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
+          int VecSize, typename Functor>
+__global__ void BroadcastKernel(
+    framework::Array<const InT *__restrict__, ET> in, OutT *out,
+    framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
+    framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
+        configlists,
+    int main_tid, int tail_tid, Functor func) {
+  int block_offset = blockIdx.x * blockDim.x * VecSize;
+  // data offset of this block
+  if (blockIdx.x < main_tid) {
+    int num = blockDim.x * VecSize;  // full segment: blockIdx.x < main_tid
+    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, false>(
+        in, out, use_broadcast, numel, configlists, num, func);
+  } else {  // remainder segment
+    int num = tail_tid;
+    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, true>(
+        in, out, use_broadcast, numel, configlists, num, func);
+  }
+}
+
+template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
+          int Size, typename Functor>
+void LaunchKernel(const platform::CUDADeviceContext &ctx,
+                  const std::vector<const framework::Tensor *> &ins,
+                  framework::Tensor *out, Functor func,
+                  DimensionsTransform merge_dims) {
+  int numel = out->numel();
+  const int threads = 256;
+  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
+
+  int main_tid = numel / (VecSize * threads);
+  int tail_tid = numel % (VecSize * threads);
+  auto stream = ctx.stream();
+  OutT *out_data = out->data<OutT>();
+
+  framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
+      configlists;
+  framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
+  framework::Array<const InT *__restrict__, ET> ins_data;
+  for (int i = 0; i < ET; i++) {
+    use_broadcast[i] = (ins[i]->numel() != numel);
+    ins_data[i] = ins[i]->data<InT>();
+    if (use_broadcast[i]) {
+      // Get the broadcast config. If the data shape is [m, n], then data_dim
+      // should be set to {n, m}.
+      // e.g.: if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}.
+      configlists[i] = kps::details::BroadcastConfig<Size>(
+          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
+    }
+  }
+
+  BroadcastKernel<ET, InT, OutT, Size, VecSize,
+                  Functor><<<blocks, threads, 0, stream>>>(
+      ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid,
+      func);
+}

 template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
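For orientation on the dimension-order convention noted in LaunchKernel above (reversed output dims, e.g. {1, 45, 3} for an output of shape [3, 45, 1]): the internals of kps::details::BroadcastConfig are not part of this diff, so the sketch below only illustrates the kind of divmod-based offset mapping such a config encapsulates. BroadcastConfigSketch, InputOffset, and the chosen shapes are hypothetical, for illustration only.

#include <cstdint>
#include <cstdio>
#include <vector>

struct BroadcastConfigSketch {
  std::vector<uint32_t> out_dims;    // output extents, fastest-varying axis first
  std::vector<uint32_t> in_strides;  // 0 on axes where the input is broadcast

  // Map a linear output index to the matching linear input index.
  uint32_t InputOffset(uint32_t out_idx) const {
    uint32_t offset = 0;
    for (size_t i = 0; i < out_dims.size(); ++i) {
      uint32_t coord = out_idx % out_dims[i];
      out_idx /= out_dims[i];
      offset += coord * in_strides[i];
    }
    return offset;
  }
};

int main() {
  // Output shape [3, 45, 1] -> reversed extents {1, 45, 3};
  // input shape [1, 45, 1] is broadcast along the first and last axes.
  BroadcastConfigSketch cfg{{1, 45, 3}, {0, 1, 0}};
  printf("output index 94 reads input element %u\n", cfg.InputOffset(94));
  return 0;
}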
@@ -365,98 +280,24 @@ void LaunchBroadcastKernelForDifferentDimSize(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
     int axis, Functor func) {
-  int numel = out->numel();
-  int threads = GetThreadsConfig(ctx, numel, VecSize);
-  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
-  int main_tid = numel / VecSize;
-  int tail_tid = numel % VecSize;
-  int vec_len = main_tid * VecSize;
-  auto stream = ctx.stream();
-
   const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
-  const auto offset_calculator = StridesCalculation(
-      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);
+
+#define DIM_SIZE(size)                                                       \
+  case size: {                                                               \
+    LaunchKernel<InT, OutT, ET, VecSize, size, Functor>(ctx, ins, out, func, \
+                                                        merge_dims);         \
+  } break;

   switch (merge_dims.dim_size) {
-    case 1: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 1>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 2: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 2>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 3: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 3>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 4: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 4>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 5: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 5>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 6: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 6>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 7: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 7>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 8: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 8>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    default: {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "The maximum dimension of input tensor is expected to be less than "
-          "%d, but received %d.\n",
-          merge_dims.dim_size, framework::DDim::kMaxRank));
-    }
+    DIM_SIZE(1);
+    DIM_SIZE(2);
+    DIM_SIZE(3);
+    DIM_SIZE(4);
+    DIM_SIZE(5);
+    DIM_SIZE(6);
+    DIM_SIZE(7);
+    DIM_SIZE(8);
   }
+#undef DIM_SIZE
 }

 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
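The DIM_SIZE macro above stamps out one switch case per supported rank, turning the runtime dim_size into a compile-time template argument. The following self-contained sketch (not from this commit; LaunchForRank and Dispatch are hypothetical names) shows the same technique in isolation:

#include <cstdio>

template <int Rank>
void LaunchForRank() {
  // Stand-in for launching a kernel specialized on the rank.
  printf("launch kernel specialized for rank %d\n", Rank);
}

void Dispatch(int rank) {
#define DIM_SIZE(size)     \
  case size: {             \
    LaunchForRank<size>(); \
  } break;
  switch (rank) {
    DIM_SIZE(1);
    DIM_SIZE(2);
    DIM_SIZE(3);
    DIM_SIZE(4);
  }
#undef DIM_SIZE
}

int main() {
  Dispatch(3);  // prints: launch kernel specialized for rank 3
  return 0;
}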
@@ -528,5 +369,7 @@ void LaunchElementwiseCudaKernel(
   }
 }

+#undef MAX_INPUT_NUM
+
 }  // namespace operators
 }  // namespace paddle
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"
@@ -26,6 +27,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+namespace kps = paddle::operators::kernel_primitives;
 enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };

 /*
@@ -67,121 +69,74 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
   return vec_size;
 }

-template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
-struct ElementwiseDataWrapper {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-
-  const InT *__restrict__ in_data[ET];
-  OutT *out_data;
-  uint32_t scalar_cal_offset;
-
-  HOSTDEVICE ElementwiseDataWrapper(
-      const std::vector<const framework::Tensor *> &ins,
-      std::vector<framework::Tensor *> *outs, uint32_t scalar_cal_offset)
-      : scalar_cal_offset(scalar_cal_offset) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      in_data[i] = ins[i]->data<InT>();
-    }
-    out_data = (*outs)[0]->data<OutT>();
-  }
-
-  inline __device__ void LoadVectorizedData(InVecType vec_args[], int tid) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      const InVecType *in_vec_data =
-          reinterpret_cast<const InVecType *>(in_data[i]);
-      vec_args[i] = in_vec_data[tid];
-    }
-  }
-
-  inline __device__ void LoadScalarizedData(InT args[], int tid) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      args[i] = in_data[i][tid + scalar_cal_offset];
-    }
-  }
-
-  inline __device__ void StoreVectorizedData(OutVecType res, int tid) {
-    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out_data);
-    out_vec[tid] = res;
-  }
-
-  inline __device__ void StoreScalarizedData(OutT res, int tid) {
-    out_data[tid + scalar_cal_offset] = res;
-  }
-};
-
-template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
-          typename InT, typename OutT, typename Functor>
-__device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
-                                            Functor func, int tid) {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-  InVecType ins_vec[ET];
-  OutVecType out_vec;
-  InT *ins_ptr[ET];
-  InT ins[ET];
-#pragma unroll
-  for (int i = 0; i < ET; ++i) {
-    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
-  }
-  // load
-  data.LoadVectorizedData(ins_vec, tid);
-
-// compute
-#pragma unroll
-  for (int i = 0; i < VecSize; ++i) {
-#pragma unroll
-    for (int j = 0; j < ET; ++j) {
-      ins[j] = ins_ptr[j][i];
-    }
-    out_vec.val[i] = func(ins);
-  }
-  // store
-  data.StoreVectorizedData(out_vec, tid);
-}
-
-template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
-          typename OutT, typename Functor>
-__device__ inline void ScalarKernelImpl(ElementwiseWrapper data, Functor func,
-                                        int tid) {
-  InT ins[ET];
-  OutT out;
-  // load
-  data.LoadScalarizedData(ins, tid);
-  // compute
-  out = func(ins);
-  // store
-  data.StoreScalarizedData(out, tid);
-}
-
-template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
-          typename OutT, int VecSize, typename Functor>
-__global__ void VectorizedKernel(ElementwiseWrapper data, int main_tid,
-                                 int tail_tid, Functor func) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  if (tid < main_tid) {
-    VectorizedKernelImpl<ET, VecSize, ElementwiseWrapper, InT, OutT, Functor>(
-        data, func, tid);
-  }
-  if (tid < tail_tid) {
-    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
-                                                                 tid);
-  }
-}
-
-template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
-          typename OutT, typename Functor>
-__global__ void ScalarKernel(ElementwiseWrapper data, int numel, Functor func) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid < numel) {
-    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
-                                                                 tid);
-  }
-}
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor, bool IsBoundary>
+__device__ void DealSegment(
+    const framework::Array<const InT *__restrict__, ET> &in, OutT *out, int num,
+    Functor func) {
+  int data_offset = VecSize * blockIdx.x * blockDim.x;
+  InT args[ET][VecSize];
+  OutT result[VecSize];
+// load data
+#pragma unroll
+  for (int i = 0; i < ET; i++) {
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
+                                                  num);
+  }
+
+  // compute
+  if (ET == kUnary) {
+    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                             func);
+  } else if (ET == kBinary) {
+    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                              args[1], func);
+  } else {
+    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
+        result, args[0], args[1], args[2], func);
+  }
+
+  // store
+  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
+                                                  num);
+}
+
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__global__ void ElementVectorizeKernel(
+    framework::Array<const InT *__restrict__, ET> in, OutT *out, int size,
+    Functor func) {
+  int data_offset = VecSize * blockIdx.x * blockDim.x;
+  int num = size - data_offset;
+  // num: the number of elements this block has to deal with
+  if (VecSize * blockDim.x > num) {  // remainder segment
+    DealSegment<ET, VecSize, InT, OutT, Functor, true>(in, out, num, func);
+  } else {  // complete segment
+    DealSegment<ET, VecSize, InT, OutT, Functor, false>(in, out, num, func);
+  }
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
+          int VecSize>
+void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
+                           const std::vector<const framework::Tensor *> &ins,
+                           std::vector<framework::Tensor *> *outs,
+                           Functor func) {
+  auto numel = ins[0]->numel();
+  int block_size = GetThreadsConfig(ctx, numel, VecSize);
+  int grid_size =
+      ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
+
+  auto stream = ctx.stream();
+  OutT *out = (*outs)[0]->data<OutT>();
+  framework::Array<const InT *__restrict__, ET> in;
+  for (int i = 0; i < ET; i++) {
+    in[i] = ins[i]->data<InT>();
+  }
+  ElementVectorizeKernel<ET, VecSize, InT, OutT,
+                         Functor><<<grid_size, block_size, 0, stream>>>(
+      in, out, numel, func);
+}

 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
@@ -190,43 +145,17 @@ void LaunchSameDimsElementwiseCudaKernel(
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
-  auto numel = ins[0]->numel();
   int vec_size = GetVectorizedSizeForIO<InT, OutT>(ins, *outs);
-  int block_size = GetThreadsConfig(ctx, numel, vec_size);
-  int grid_size =
-      ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  int main_tid = numel / vec_size;
-  int tail_tid = numel % vec_size;
-  uint32_t vec_len = main_tid * vec_size;
-
-  // cuda kernel
-  auto stream = ctx.stream();
-
   switch (vec_size) {
-    case 4: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 4, InT, OutT>(ins, outs, vec_len);
-      VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
-                       4><<<grid_size, block_size, 0, stream>>>(
-          data_wrapper, main_tid, tail_tid, func);
+    case 4:
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, 4>(ctx, ins, outs, func);
       break;
-    }
-    case 2: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 2, InT, OutT>(ins, outs, vec_len);
-      VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
-                       2><<<grid_size, block_size, 0, stream>>>(
-          data_wrapper, main_tid, tail_tid, func);
+    case 2:
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, 2>(ctx, ins, outs, func);
       break;
-    }
-    case 1: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 1, InT, OutT>(ins, outs, 0);
-      ScalarKernel<ET, decltype(data_wrapper), InT,
-                   OutT><<<grid_size, block_size, 0, stream>>>(data_wrapper,
-                                                               numel, func);
+    case 1:
+      ElementwiseCudaKernel<ET, InT, OutT, Functor, 1>(ctx, ins, outs, func);
       break;
-    }
     default: {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported vectorized size: %d !", vec_size));
...