未验证 提交 eae4bf5b 编写于 作者: N niuliling123 提交者: GitHub

Modify the elementwise op according to the kernel primitive API (#34456)

上级 b211f02b
......@@ -15,10 +15,14 @@
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
namespace paddle {
namespace operators {
#define MAX_INPUT_NUM 3 // the max num of ET for BroadcacstConfig
namespace kps = paddle::operators::kernel_primitives;
struct DimensionsTransform {
using DimVector = std::vector<int64_t>;
typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
......@@ -161,202 +165,113 @@ struct DimensionsTransform {
}
};
struct StridesCalculation {
std::vector<std::vector<uint32_t>> strides;
std::vector<platform::FastDivMod> divmoders;
private:
// To calculate the strides of each input_tensor.
__inline__ void CalculateStrides(
int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
for (int j = 0; j < N; ++j) {
for (int i = 0; i < dim_size; ++i) {
strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
strides[j][i] =
(i != 0 && strides[j][i] != 0)
? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i, 1,
std::multiplies<int64_t>())
: strides[j][i];
}
}
}
public:
explicit StridesCalculation(const int64_t &dim_size,
const std::vector<std::vector<int64_t>> &in_dims,
const std::vector<int64_t> &out_dims) {
const auto N = in_dims.size();
divmoders.resize(dim_size);
strides.resize(N, std::vector<uint32_t>(dim_size, 1));
for (int i = 0; i < dim_size; ++i) {
divmoders[i] = platform::FastDivMod(out_dims[i]);
}
CalculateStrides(N, dim_size, in_dims);
}
};
template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
int VecSize, int kDims>
struct BroadcastArgsWrapper {
using InVecType = platform::AlignedVector<InT, VecSize>;
using OutVecType = platform::AlignedVector<OutT, VecSize>;
OutT *out_data;
OutVecType *vec_out_data;
const InT *__restrict__ in_data[ET];
const InVecType *__restrict__ vec_in_data[ET];
bool no_broadcast[ET];
platform::FastDivMod divmoders[kDims];
uint32_t strides[ET][framework::DDim::kMaxRank];
uint32_t scalar_cal_offset;
Functor func;
HOSTDEVICE BroadcastArgsWrapper(
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int scalar_cal_offset, Functor func,
const StridesCalculation &offset_calculator)
: scalar_cal_offset(scalar_cal_offset), func(func) {
for (int j = 0; j < ET; ++j) {
in_data[j] = ins[j]->data<InT>();
vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
memcpy(strides[j], offset_calculator.strides[j].data(),
kDims * sizeof(uint32_t));
}
out_data = out->data<OutT>();
vec_out_data = reinterpret_cast<OutVecType *>(out_data);
memcpy(divmoders, offset_calculator.divmoders.data(),
kDims * sizeof(platform::FastDivMod));
}
__device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
uint32_t offset = 0;
#pragma unroll(kDims)
for (int i = 0; i < kDims; ++i) {
auto fast_divmoder = divmoders[i].Divmod(idx);
idx = fast_divmoder.val[0];
offset += fast_divmoder.val[1] * strides[in_idx][i];
}
return offset;
}
__device__ __forceinline__ void LoadVectorizedDataCommon(
InVecType *vector_args, int tid, int idx) {
*vector_args = vec_in_data[idx][tid];
}
__device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
int tid, int idx) {
int index = tid * VecSize;
#pragma unroll(VecSize)
for (int i = 0; i < VecSize; ++i) {
uint32_t offset = GetOffsetByDivmod(index + i, idx);
scalar_args[i] = in_data[idx][offset];
}
}
__device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
int idx) {
args[idx] = in_data[idx][tid + scalar_cal_offset];
}
__device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
int tid, int idx) {
auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
args[idx] = in_data[idx][offset];
}
__device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
int tid) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
if (no_broadcast[j]) {
InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
LoadVectorizedDataCommon(vector_args, tid, j);
template <typename T, int VecSize, int ShapeSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst, const T *__restrict__ src, uint32_t block_offset,
const kps::details::BroadcastConfig<ShapeSize> &config, int numel, int num,
bool need_broadcast) {
// numel : whole num of output
// num: how many data will be deal with in this time
if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, 1, ShapeSize, IsBoundary>(
dst, src, block_offset, config, numel, 1, 1);
} else {
LoadVectorizedDataByDivmod(args[j], tid, j);
}
}
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
}
}
__device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
if (no_broadcast[j]) {
LoadScalarizedDataCommon(args, tid, j);
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
int VecSize, typename Functor, bool IsBoundary = false>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel,
const framework::Array<kps::details::BroadcastConfig<ShapeSize>,
MAX_INPUT_NUM> &configlists,
int num, Functor func) {
InT args[ET][VecSize];
OutT result[VecSize];
int block_offset = blockIdx.x * blockDim.x * VecSize;
// load
#pragma unroll
for (int i = 0; i < ET; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
LoadData<InT, VecSize, ShapeSize, IsBoundary>(args[i], in[i], block_offset,
configlists[i], numel, num,
use_broadcast[i]);
}
// compute
if (ET == kUnary) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
func);
} else if (ET == kBinary) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
args[1], func);
} else {
LoadScalarizedDataByDivmod(args, tid, j);
}
}
}
__device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
int tid) {
vec_out_data[tid] = vec_args_out;
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
// compute
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
num);
}
__device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
out_data[scalar_cal_offset + tid] = args_out;
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
int VecSize, typename Functor>
__global__ void BroadcastKernel(
framework::Array<const InT *__restrict__, ET> in, OutT *out,
framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
configlists,
int main_tid, int tail_tid, Functor func) {
int block_offset = blockIdx.x * blockDim.x * VecSize;
// data offset of this block
if (blockIdx.x < main_tid) {
int num = blockDim.x * VecSize; // blockIdx.x < main_tid
DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, false>(
in, out, use_broadcast, numel, configlists, num, func);
} else { // reminder
int num = tail_tid;
DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, true>(
in, out, use_broadcast, numel, configlists, num, func);
}
};
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
BroadcastArgsWrapper broadcast_wrapper, int tid) {
InT args[ET];
OutT args_out;
broadcast_wrapper.LoadScalarizedData(args, tid);
}
// Calcualtion of the in_tensor data.
args_out = broadcast_wrapper.func(args);
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
int Size, typename Functor>
void LaunchKernel(const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
framework::Tensor *out, Functor func,
DimensionsTransform merge_dims) {
int numel = out->numel();
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
broadcast_wrapper.StoreScalarizedData(args_out, tid);
}
int main_tid = numel / (VecSize * threads);
int tail_tid = numel % (VecSize * threads);
auto stream = ctx.stream();
OutT *out_data = out->data<OutT>();
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
BroadcastArgsWrapper broadcast_wrapper, int tid) {
using OutVecType = platform::AlignedVector<OutT, VecSize>;
OutVecType args_out;
InT ins[ET];
InT args[ET][VecSize];
broadcast_wrapper.LoadVectorizedData(args, tid);
framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
configlists;
framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
framework::Array<const InT *__restrict__, ET> ins_data;
#pragma unroll(VecSize)
for (int i = 0; i < VecSize; ++i) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
ins[j] = args[j][i];
for (int i = 0; i < ET; i++) {
use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = ins[i]->data<InT>();
if (use_broadcast[i]) {
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
// eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
configlists[i] = kps::details::BroadcastConfig<Size>(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
args_out.val[i] = broadcast_wrapper.func(ins);
}
broadcast_wrapper.StoreVectorizedData(args_out, tid);
}
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
ElementwiseType ET, int VecSize>
__global__ void ElementwiseBroadcastKernel(
BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
// Vectorized calculation of major data whose length is the max multipler of
// VecSize,
// eg: Calcualting the front 1024-length data in total 1027 data once VecSize
// is 4.
if (tid < main_tid) {
VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET, VecSize>(
broadcast_wrapper, tid);
}
// Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
// eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
// 4.
if (tid < tail_tid) {
ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET>(
broadcast_wrapper, tid);
}
BroadcastKernel<ET, InT, OutT, Size, VecSize,
Functor><<<blocks, threads, 0, stream>>>(
ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid,
func);
}
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
......@@ -365,98 +280,24 @@ void LaunchBroadcastKernelForDifferentDimSize(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int axis, Functor func) {
int numel = out->numel();
int threads = GetThreadsConfig(ctx, numel, VecSize);
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / VecSize;
int tail_tid = numel % VecSize;
int vec_len = main_tid * VecSize;
auto stream = ctx.stream();
const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
const auto offset_calculator = StridesCalculation(
merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);
#define DIM_SIZE(size) \
case size: { \
LaunchKernel<InT, OutT, ET, VecSize, size, Functor>(ctx, ins, out, func, \
merge_dims); \
} break;
switch (merge_dims.dim_size) {
case 1: {
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 1>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 2: {
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 2>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 3: {
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 3>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 4: {
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 4>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 5: {
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 5>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 6: {
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 6>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 7: {
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 7>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 8: {
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 8>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_wrapper, main_tid, tail_tid);
break;
}
default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"The maximum dimension of input tensor is expected to be less than "
"%d, but recieved %d.\n",
merge_dims.dim_size, framework::DDim::kMaxRank));
}
}
DIM_SIZE(1);
DIM_SIZE(2);
DIM_SIZE(3);
DIM_SIZE(4);
DIM_SIZE(5);
DIM_SIZE(6);
DIM_SIZE(7);
DIM_SIZE(8);
}
#undef DIM_SIZE
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
......@@ -528,5 +369,7 @@ void LaunchElementwiseCudaKernel(
}
}
#undef MAX_INPUT_NUM
} // namespace operators
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"
......@@ -26,6 +27,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace kps = paddle::operators::kernel_primitives;
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };
/*
......@@ -67,121 +69,74 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
return vec_size;
}
template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
struct ElementwiseDataWrapper {
using InVecType = platform::AlignedVector<InT, VecSize>;
using OutVecType = platform::AlignedVector<OutT, VecSize>;
const InT *__restrict__ in_data[ET];
OutT *out_data;
uint32_t scalar_cal_offset;
HOSTDEVICE ElementwiseDataWrapper(
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, uint32_t scalar_cal_offset)
: scalar_cal_offset(scalar_cal_offset) {
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor, bool IsBoundary>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, ET> &in, OutT *out, int num,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
InT args[ET][VecSize];
OutT result[VecSize];
// load data
#pragma unroll
for (int i = 0; i < ET; ++i) {
in_data[i] = ins[i]->data<InT>();
}
out_data = (*outs)[0]->data<OutT>();
for (int i = 0; i < ET; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
num);
}
inline __device__ void LoadVectorizedData(InVecType vec_args[], int tid) {
#pragma unroll
for (int i = 0; i < ET; ++i) {
const InVecType *in_vec_data =
reinterpret_cast<const InVecType *>(in_data[i]);
vec_args[i] = in_vec_data[tid];
}
}
inline __device__ void LoadScalarizedData(InT args[], int tid) {
#pragma unroll
for (int i = 0; i < ET; ++i) {
args[i] = in_data[i][tid + scalar_cal_offset];
}
}
inline __device__ void StoreVectorizedData(OutVecType res, int tid) {
OutVecType *out_vec = reinterpret_cast<OutVecType *>(out_data);
out_vec[tid] = res;
}
inline __device__ void StoreScalarizedData(OutT res, int tid) {
out_data[tid + scalar_cal_offset] = res;
}
};
template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
typename InT, typename OutT, typename Functor>
__device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
Functor func, int tid) {
using InVecType = platform::AlignedVector<InT, VecSize>;
using OutVecType = platform::AlignedVector<OutT, VecSize>;
InVecType ins_vec[ET];
OutVecType out_vec;
InT *ins_ptr[ET];
InT ins[ET];
#pragma unroll
for (int i = 0; i < ET; ++i) {
ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
}
// load
data.LoadVectorizedData(ins_vec, tid);
// compute
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
#pragma unroll
for (int j = 0; j < ET; ++j) {
ins[j] = ins_ptr[j][i];
}
out_vec.val[i] = func(ins);
// compute
if (ET == kUnary) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
func);
} else if (ET == kBinary) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
args[1], func);
} else {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
// store
data.StoreVectorizedData(out_vec, tid);
}
template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
typename OutT, typename Functor>
__device__ inline void ScalarKernelImpl(ElementwiseWrapper data, Functor func,
int tid) {
InT ins[ET];
OutT out;
// load
data.LoadScalarizedData(ins, tid);
// compute
out = func(ins);
// store
data.StoreScalarizedData(out, tid);
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
num);
}
template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
typename OutT, int VecSize, typename Functor>
__global__ void VectorizedKernel(ElementwiseWrapper data, int main_tid,
int tail_tid, Functor func) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < main_tid) {
VectorizedKernelImpl<ET, VecSize, ElementwiseWrapper, InT, OutT, Functor>(
data, func, tid);
}
if (tid < tail_tid) {
ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
tid);
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__global__ void ElementVectorizeKernel(
framework::Array<const InT *__restrict__, ET> in, OutT *out, int size,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
int num = size - data_offset;
// the num this time have to deal with
if (VecSize * blockDim.x > num) { // reminder segment
DealSegment<ET, VecSize, InT, OutT, Functor, true>(in, out, num, func);
} else { // complete segment
DealSegment<ET, VecSize, InT, OutT, Functor, false>(in, out, num, func);
}
}
template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
typename OutT, typename Functor>
__global__ void ScalarKernel(ElementwiseWrapper data, int numel, Functor func) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < numel) {
ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
tid);
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
int VecSize>
void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs,
Functor func) {
auto numel = ins[0]->numel();
int block_size = GetThreadsConfig(ctx, numel, VecSize);
int grid_size =
((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
auto stream = ctx.stream();
OutT *out = (*outs)[0]->data<OutT>();
framework::Array<const InT *__restrict__, ET> in;
for (int i = 0; i < ET; i++) {
in[i] = ins[i]->data<InT>();
}
ElementVectorizeKernel<ET, VecSize, InT, OutT,
Functor><<<grid_size, block_size, 0, stream>>>(
in, out, numel, func);
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
......@@ -190,43 +145,17 @@ void LaunchSameDimsElementwiseCudaKernel(
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, Functor func) {
// calculate the max vec_size for all ins and outs
auto numel = ins[0]->numel();
int vec_size = GetVectorizedSizeForIO<InT, OutT>(ins, *outs);
int block_size = GetThreadsConfig(ctx, numel, vec_size);
int grid_size =
((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
int main_tid = numel / vec_size;
int tail_tid = numel % vec_size;
uint32_t vec_len = main_tid * vec_size;
// cuda kernel
auto stream = ctx.stream();
switch (vec_size) {
case 4: {
auto data_wrapper =
ElementwiseDataWrapper<ET, 4, InT, OutT>(ins, outs, vec_len);
VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
4><<<grid_size, block_size, 0, stream>>>(
data_wrapper, main_tid, tail_tid, func);
case 4:
ElementwiseCudaKernel<ET, InT, OutT, Functor, 4>(ctx, ins, outs, func);
break;
}
case 2: {
auto data_wrapper =
ElementwiseDataWrapper<ET, 2, InT, OutT>(ins, outs, vec_len);
VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
2><<<grid_size, block_size, 0, stream>>>(
data_wrapper, main_tid, tail_tid, func);
case 2:
ElementwiseCudaKernel<ET, InT, OutT, Functor, 2>(ctx, ins, outs, func);
break;
}
case 1: {
auto data_wrapper =
ElementwiseDataWrapper<ET, 1, InT, OutT>(ins, outs, 0);
ScalarKernel<ET, decltype(data_wrapper), InT,
OutT><<<grid_size, block_size, 0, stream>>>(data_wrapper,
numel, func);
case 1:
ElementwiseCudaKernel<ET, InT, OutT, Functor, 1>(ctx, ins, outs, func);
break;
}
default: {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册