Unverified commit d6aea4ac authored by limingshu, committed by GitHub

Support OutType template argument in elementwise_broadcast branch (#33060)

Parent a6dc68b7
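The diff below threads a separate output element type (`OutT` / `OutType`) through the broadcast code path, which previously reused a single type `T` for both inputs and output. The separation matters whenever a binary functor produces a result whose type differs from its operands; the minimal host-side sketch below illustrates that situation with a hypothetical `GreaterThanFunctor` over `float` inputs and `bool` outputs (an illustration only, not code from this commit):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical binary functor: operands are InT, the result is OutT.
// With a single template parameter T, a kernel would have to write this
// result back into a T-typed buffer; templating on OutT separately avoids that.
template <typename InT, typename OutT>
struct GreaterThanFunctor {
  OutT operator()(const InT *args) const { return args[0] > args[1]; }
};

int main() {
  std::vector<float> x = {1.f, 5.f, 3.f};
  std::vector<float> y = {2.f, 4.f, 3.f};
  std::vector<bool> out(x.size());  // OutT differs from InT.

  GreaterThanFunctor<float, bool> func;
  for (size_t i = 0; i < x.size(); ++i) {
    float args[2] = {x[i], y[i]};
    out[i] = func(args);
  }
  for (bool v : out) std::printf("%d ", static_cast<int>(v));
  std::printf("\n");  // prints: 0 1 0
  return 0;
}
```

In the patch, the broadcast wrapper and kernels are templated on both `InT` and `OutT`, so a functor like the sketch above can write into an output tensor of a different element type, matching what the same-dims path (`LaunchSameDimsElementwiseCudaKernel<..., InT, OutType>`) already supported.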
@@ -196,15 +196,16 @@ struct StridesCalculation {
   }
 };
 
-template <typename T, typename Functor, ElementwiseType ET, int VecSize,
-          int kDims>
+template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
+          int VecSize, int kDims>
 struct BroadcastArgsWarpper {
-  using VecType = CudaAlignedVector<T, VecSize>;
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
 
-  T *out_data;
-  VecType *vec_out_data;
-  const T *__restrict__ in_data[ET];
-  const VecType *__restrict__ vec_in_data[ET];
+  OutT *out_data;
+  OutVecType *vec_out_data;
+  const InT *__restrict__ in_data[ET];
+  const InVecType *__restrict__ vec_in_data[ET];
   bool no_broadcast[ET];
   FastDivMod divmoders[kDims];
   uint32_t strides[ET][framework::DDim::kMaxRank];
@@ -217,14 +218,14 @@ struct BroadcastArgsWarpper {
       const StridesCalculation &offset_calculator)
       : scalar_cal_offset(scalar_cal_offset), func(func) {
     for (int j = 0; j < ET; ++j) {
-      in_data[j] = ins[j]->data<T>();
-      vec_in_data[j] = reinterpret_cast<const VecType *>(in_data[j]);
+      in_data[j] = ins[j]->data<InT>();
+      vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
       no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
       memcpy(strides[j], offset_calculator.strides[j].data(),
              kDims * sizeof(uint32_t));
     }
-    out_data = out->data<T>();
-    vec_out_data = reinterpret_cast<VecType *>(out_data);
+    out_data = out->data<OutT>();
+    vec_out_data = reinterpret_cast<OutVecType *>(out_data);
     memcpy(divmoders, offset_calculator.divmoders.data(),
            kDims * sizeof(FastDivMod));
   }
@@ -241,12 +242,12 @@ struct BroadcastArgsWarpper {
     return offset;
   }
 
-  __device__ __forceinline__ void LoadVectorizedDataCommon(VecType *vector_args,
-                                                           int tid, int idx) {
+  __device__ __forceinline__ void LoadVectorizedDataCommon(
+      InVecType *vector_args, int tid, int idx) {
     *vector_args = vec_in_data[idx][tid];
   }
 
-  __device__ __forceinline__ void LoadVectorizedDataByDivmod(T *scalar_args,
+  __device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
                                                              int tid, int idx) {
     int index = tid * VecSize;
 #pragma unroll(VecSize)
@@ -256,23 +257,23 @@ struct BroadcastArgsWarpper {
     }
   }
 
-  __device__ __forceinline__ void LoadScalarizedDataCommon(T args[], int tid,
+  __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
                                                            int idx) {
     args[idx] = in_data[idx][tid + scalar_cal_offset];
   }
 
-  __device__ __forceinline__ void LoadScalarizedDataByDivmod(T args[], int tid,
-                                                             int idx) {
+  __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
+                                                             int tid, int idx) {
     auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
     args[idx] = in_data[idx][offset];
   }
 
-  __device__ __forceinline__ void LoadVectorizedData(T (*args)[VecSize],
+  __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
                                                      int tid) {
 #pragma unroll(ET)
     for (int j = 0; j < ET; ++j) {
       if (no_broadcast[j]) {
-        VecType *vector_args = reinterpret_cast<VecType *>(args[j]);
+        InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
         LoadVectorizedDataCommon(vector_args, tid, j);
       } else {
         LoadVectorizedDataByDivmod(args[j], tid, j);
@@ -280,7 +281,7 @@ struct BroadcastArgsWarpper {
     }
   }
 
-  __device__ __forceinline__ void LoadScalarizedData(T args[], int tid) {
+  __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
 #pragma unroll(ET)
     for (int j = 0; j < ET; ++j) {
       if (no_broadcast[j]) {
@@ -291,36 +292,39 @@ struct BroadcastArgsWarpper {
     }
   }
 
-  __device__ __forceinline__ void StoreVectorizedData(T (*args)[VecSize],
+  __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
                                                       int tid) {
-    VecType *args_out = reinterpret_cast<VecType *>(args[0]);
-    vec_out_data[tid] = *args_out;
+    vec_out_data[tid] = vec_args_out;
   }
 
-  __device__ __forceinline__ void StoreScalarizedData(T args[], int tid) {
-    out_data[scalar_cal_offset + tid] = args[0];
+  __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
+    out_data[scalar_cal_offset + tid] = args_out;
   }
 };
 
-template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET>
+template <typename InT, typename OutT, typename BroadcastArgsWarpper,
+          ElementwiseType ET>
 __device__ inline void ScalarizedBroadcastKernelImpl(
     BroadcastArgsWarpper broadcast_warpper, int tid) {
-  T args[ET];
+  InT args[ET];
+  OutT args_out;
   broadcast_warpper.LoadScalarizedData(args, tid);
 
 #pragma unroll(ET)
   for (int j = 1; j < ET; ++j) {
-    args[0] = broadcast_warpper.func(args);
+    args_out = broadcast_warpper.func(args);
   }
-  broadcast_warpper.StoreScalarizedData(args, tid);
+  broadcast_warpper.StoreScalarizedData(args_out, tid);
 }
 
-template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
-          int VecSize>
+template <typename InT, typename OutT, typename BroadcastArgsWarpper,
+          ElementwiseType ET, int VecSize>
 __device__ inline void VectorizedBroadcastKernelImpl(
     BroadcastArgsWarpper broadcast_warpper, int tid) {
-  T ins[ET];
-  T args[ET][VecSize];
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
+  OutVecType args_out;
+  InT ins[ET];
+  InT args[ET][VecSize];
   broadcast_warpper.LoadVectorizedData(args, tid);
 
 #pragma unroll(VecSize)
@@ -329,13 +333,13 @@ __device__ inline void VectorizedBroadcastKernelImpl(
     for (int j = 0; j < ET; ++j) {
       ins[j] = args[j][i];
     }
-    args[0][i] = broadcast_warpper.func(ins);
+    args_out.val[i] = broadcast_warpper.func(ins);
   }
-  broadcast_warpper.StoreVectorizedData(args, tid);
+  broadcast_warpper.StoreVectorizedData(args_out, tid);
 }
 
-template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
-          int VecSize>
+template <typename InT, typename OutT, typename BroadcastArgsWarpper,
+          ElementwiseType ET, int VecSize>
 __global__ void ElementwiseBroadcastKernel(
     BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -345,19 +349,20 @@ __global__ void ElementwiseBroadcastKernel(
   // eg: Calcualting the front 1024-length data in total 1027 data once VecSize
   // is 4.
   if (tid < main_tid) {
-    VectorizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET, VecSize>(
+    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET, VecSize>(
         broadcast_warpper, tid);
   }
   // Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
   // eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
   // 4.
   if (tid < tail_tid) {
-    ScalarizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET>(
+    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET>(
         broadcast_warpper, tid);
   }
 }
 
-template <typename T, ElementwiseType ET, int VecSize, typename Functor>
+template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
+          typename Functor>
 void LaunchBroadcastKernelForDifferentDimSize(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
@@ -376,65 +381,73 @@ void LaunchBroadcastKernelForDifferentDimSize(
   switch (merge_dims.dim_size) {
     case 1: {
-      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 1>(
-          ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
+      auto broadcast_warpper =
+          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 1>(
+              ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
           broadcast_warpper, main_tid, tail_tid);
       break;
     }
     case 2: {
-      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 2>(
-          ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
+      auto broadcast_warpper =
+          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 2>(
+              ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 3: {
-      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 3>(
-          ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
+      auto broadcast_warpper =
+          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 3>(
+              ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 4: {
-      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 4>(
-          ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
+      auto broadcast_warpper =
+          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 4>(
+              ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 5: {
-      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 5>(
-          ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
+      auto broadcast_warpper =
+          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 5>(
+              ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 6: {
-      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 6>(
-          ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
+      auto broadcast_warpper =
+          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 6>(
+              ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 7: {
-      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 7>(
-          ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
+      auto broadcast_warpper =
+          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 7>(
+              ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 8: {
-      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 8>(
-          ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
+      auto broadcast_warpper =
+          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 8>(
+              ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
@@ -448,7 +461,7 @@ void LaunchBroadcastKernelForDifferentDimSize(
   }
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
 void LaunchBroadcastElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
@@ -457,27 +470,27 @@ void LaunchBroadcastElementwiseCudaKernel(
   int in_vec_size = 4;
   framework::Tensor *out = (*outs)[0];
   for (auto *in : ins) {
-    auto temp_size = GetVectorizedSizeImpl<T>(in->data<T>());
+    auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
     in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                             : in_vec_size;
   }
 
-  int out_vec_size = GetVectorizedSizeImpl<T>(out->data<T>());
+  int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
   int vec_size = std::min(out_vec_size, in_vec_size);
   switch (vec_size) {
     case 4: {
-      LaunchBroadcastKernelForDifferentDimSize<T, ET, 4>(ctx, ins, out, axis,
-                                                         func);
+      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
+                                                                 axis, func);
      break;
    }
    case 2: {
-      LaunchBroadcastKernelForDifferentDimSize<T, ET, 2>(ctx, ins, out, axis,
-                                                         func);
+      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
+                                                                 axis, func);
      break;
    }
    case 1: {
-      LaunchBroadcastKernelForDifferentDimSize<T, ET, 1>(ctx, ins, out, axis,
-                                                         func);
+      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
+                                                                 axis, func);
      break;
    }
    default: {
@@ -502,8 +515,9 @@ void LaunchElementwiseCudaKernel(
     LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutType>(
         cuda_ctx, ins, outs, func);
   } else {
-    LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT>(
-        cuda_ctx, ins, outs, axis, func);
+    LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT,
+                                         OutType>(cuda_ctx, ins, outs, axis,
+                                                  func);
   }
 }
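For context on the `StoreVectorizedData` change above: the old code reinterpreted the first input slot (`args[0]`, typed `T`) as the output vector, which only works when the output type matches the input type. The new code builds an explicitly typed `OutVecType` value and stores it directly, and the scalar path likewise stores a single `OutT` value. Below is a standalone CUDA sketch of that load/compute/store shape; `AlignedVector` is a simplified stand-in for Paddle's `CudaAlignedVector`, and the `CastAndScale` functor and all names are hypothetical, assuming InT=float and OutT=double:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Simplified stand-in for CudaAlignedVector<T, VecSize>: VecSize elements
// loaded/stored as one aligned value.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVector {
  T val[VecSize];
};

// Hypothetical unary functor whose output type differs from its input type.
struct CastAndScale {
  __device__ double operator()(float x) const {
    return 2.0 * static_cast<double>(x);
  }
};

// Vectorized pattern mirrored from the patch: load an InT vector, compute per
// lane, then store an OutT vector (no reinterpretation of the input buffer).
template <typename InT, typename OutT, int VecSize, typename Functor>
__global__ void VectorizedApply(const InT *in, OutT *out, int n, Functor func) {
  using InVec = AlignedVector<InT, VecSize>;
  using OutVec = AlignedVector<OutT, VecSize>;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid * VecSize >= n) return;
  InVec args = reinterpret_cast<const InVec *>(in)[tid];
  OutVec args_out;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    args_out.val[i] = func(args.val[i]);
  }
  reinterpret_cast<OutVec *>(out)[tid] = args_out;
}

int main() {
  const int n = 8, vec_size = 4;
  float h_in[n] = {0, 1, 2, 3, 4, 5, 6, 7};
  double h_out[n];
  float *d_in;
  double *d_out;
  cudaMalloc(reinterpret_cast<void **>(&d_in), n * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&d_out), n * sizeof(double));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  VectorizedApply<float, double, vec_size><<<1, n / vec_size>>>(
      d_in, d_out, n, CastAndScale());
  cudaMemcpy(h_out, d_out, n * sizeof(double), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) std::printf("%.1f ", h_out[i]);
  std::printf("\n");  // prints: 0.0 2.0 4.0 6.0 8.0 10.0 12.0 14.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```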