未验证 提交 d6aea4ac 编写于 作者: L limingshu 提交者: GitHub

Support OutType tmeplate argument in elementwise_broadcast branch (#33060)

上级 a6dc68b7
......@@ -196,15 +196,16 @@ struct StridesCalculation {
}
};
template <typename T, typename Functor, ElementwiseType ET, int VecSize,
int kDims>
template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
int VecSize, int kDims>
struct BroadcastArgsWarpper {
using VecType = CudaAlignedVector<T, VecSize>;
using InVecType = CudaAlignedVector<InT, VecSize>;
using OutVecType = CudaAlignedVector<OutT, VecSize>;
T *out_data;
VecType *vec_out_data;
const T *__restrict__ in_data[ET];
const VecType *__restrict__ vec_in_data[ET];
OutT *out_data;
OutVecType *vec_out_data;
const InT *__restrict__ in_data[ET];
const InVecType *__restrict__ vec_in_data[ET];
bool no_broadcast[ET];
FastDivMod divmoders[kDims];
uint32_t strides[ET][framework::DDim::kMaxRank];
......@@ -217,14 +218,14 @@ struct BroadcastArgsWarpper {
const StridesCalculation &offset_calculator)
: scalar_cal_offset(scalar_cal_offset), func(func) {
for (int j = 0; j < ET; ++j) {
in_data[j] = ins[j]->data<T>();
vec_in_data[j] = reinterpret_cast<const VecType *>(in_data[j]);
in_data[j] = ins[j]->data<InT>();
vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
memcpy(strides[j], offset_calculator.strides[j].data(),
kDims * sizeof(uint32_t));
}
out_data = out->data<T>();
vec_out_data = reinterpret_cast<VecType *>(out_data);
out_data = out->data<OutT>();
vec_out_data = reinterpret_cast<OutVecType *>(out_data);
memcpy(divmoders, offset_calculator.divmoders.data(),
kDims * sizeof(FastDivMod));
}
......@@ -241,12 +242,12 @@ struct BroadcastArgsWarpper {
return offset;
}
__device__ __forceinline__ void LoadVectorizedDataCommon(VecType *vector_args,
int tid, int idx) {
__device__ __forceinline__ void LoadVectorizedDataCommon(
InVecType *vector_args, int tid, int idx) {
*vector_args = vec_in_data[idx][tid];
}
__device__ __forceinline__ void LoadVectorizedDataByDivmod(T *scalar_args,
__device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
int tid, int idx) {
int index = tid * VecSize;
#pragma unroll(VecSize)
......@@ -256,23 +257,23 @@ struct BroadcastArgsWarpper {
}
}
__device__ __forceinline__ void LoadScalarizedDataCommon(T args[], int tid,
__device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
int idx) {
args[idx] = in_data[idx][tid + scalar_cal_offset];
}
__device__ __forceinline__ void LoadScalarizedDataByDivmod(T args[], int tid,
int idx) {
__device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
int tid, int idx) {
auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
args[idx] = in_data[idx][offset];
}
__device__ __forceinline__ void LoadVectorizedData(T (*args)[VecSize],
__device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
int tid) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
if (no_broadcast[j]) {
VecType *vector_args = reinterpret_cast<VecType *>(args[j]);
InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
LoadVectorizedDataCommon(vector_args, tid, j);
} else {
LoadVectorizedDataByDivmod(args[j], tid, j);
......@@ -280,7 +281,7 @@ struct BroadcastArgsWarpper {
}
}
__device__ __forceinline__ void LoadScalarizedData(T args[], int tid) {
__device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
if (no_broadcast[j]) {
......@@ -291,36 +292,39 @@ struct BroadcastArgsWarpper {
}
}
__device__ __forceinline__ void StoreVectorizedData(T (*args)[VecSize],
__device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
int tid) {
VecType *args_out = reinterpret_cast<VecType *>(args[0]);
vec_out_data[tid] = *args_out;
vec_out_data[tid] = vec_args_out;
}
__device__ __forceinline__ void StoreScalarizedData(T args[], int tid) {
out_data[scalar_cal_offset + tid] = args[0];
__device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
out_data[scalar_cal_offset + tid] = args_out;
}
};
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET>
template <typename InT, typename OutT, typename BroadcastArgsWarpper,
ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
BroadcastArgsWarpper broadcast_warpper, int tid) {
T args[ET];
InT args[ET];
OutT args_out;
broadcast_warpper.LoadScalarizedData(args, tid);
#pragma unroll(ET)
for (int j = 1; j < ET; ++j) {
args[0] = broadcast_warpper.func(args);
args_out = broadcast_warpper.func(args);
}
broadcast_warpper.StoreScalarizedData(args, tid);
broadcast_warpper.StoreScalarizedData(args_out, tid);
}
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
int VecSize>
template <typename InT, typename OutT, typename BroadcastArgsWarpper,
ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
BroadcastArgsWarpper broadcast_warpper, int tid) {
T ins[ET];
T args[ET][VecSize];
using OutVecType = CudaAlignedVector<OutT, VecSize>;
OutVecType args_out;
InT ins[ET];
InT args[ET][VecSize];
broadcast_warpper.LoadVectorizedData(args, tid);
#pragma unroll(VecSize)
......@@ -329,13 +333,13 @@ __device__ inline void VectorizedBroadcastKernelImpl(
for (int j = 0; j < ET; ++j) {
ins[j] = args[j][i];
}
args[0][i] = broadcast_warpper.func(ins);
args_out.val[i] = broadcast_warpper.func(ins);
}
broadcast_warpper.StoreVectorizedData(args, tid);
broadcast_warpper.StoreVectorizedData(args_out, tid);
}
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
int VecSize>
template <typename InT, typename OutT, typename BroadcastArgsWarpper,
ElementwiseType ET, int VecSize>
__global__ void ElementwiseBroadcastKernel(
BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
......@@ -345,19 +349,20 @@ __global__ void ElementwiseBroadcastKernel(
// eg: Calcualting the front 1024-length data in total 1027 data once VecSize
// is 4.
if (tid < main_tid) {
VectorizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET, VecSize>(
VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET, VecSize>(
broadcast_warpper, tid);
}
// Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
// eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
// 4.
if (tid < tail_tid) {
ScalarizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET>(
ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET>(
broadcast_warpper, tid);
}
}
template <typename T, ElementwiseType ET, int VecSize, typename Functor>
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
......@@ -376,65 +381,73 @@ void LaunchBroadcastKernelForDifferentDimSize(
switch (merge_dims.dim_size) {
case 1: {
auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 1>(
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 1>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 2: {
auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 2>(
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 2>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 3: {
auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 3>(
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 3>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 4: {
auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 4>(
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 4>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 5: {
auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 5>(
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 5>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 6: {
auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 6>(
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 6>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 7: {
auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 7>(
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 7>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 8: {
auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 8>(
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 8>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
......@@ -448,7 +461,7 @@ void LaunchBroadcastKernelForDifferentDimSize(
}
}
template <ElementwiseType ET, typename T, typename Functor>
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
......@@ -457,27 +470,27 @@ void LaunchBroadcastElementwiseCudaKernel(
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = GetVectorizedSizeImpl<T>(in->data<T>());
auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
}
int out_vec_size = GetVectorizedSizeImpl<T>(out->data<T>());
int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);
switch (vec_size) {
case 4: {
LaunchBroadcastKernelForDifferentDimSize<T, ET, 4>(ctx, ins, out, axis,
func);
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
axis, func);
break;
}
case 2: {
LaunchBroadcastKernelForDifferentDimSize<T, ET, 2>(ctx, ins, out, axis,
func);
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
axis, func);
break;
}
case 1: {
LaunchBroadcastKernelForDifferentDimSize<T, ET, 1>(ctx, ins, out, axis,
func);
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
axis, func);
break;
}
default: {
......@@ -502,8 +515,9 @@ void LaunchElementwiseCudaKernel(
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutType>(
cuda_ctx, ins, outs, func);
} else {
LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT>(
cuda_ctx, ins, outs, axis, func);
LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT,
OutType>(cuda_ctx, ins, outs, axis,
func);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册