diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index aeef6ee71446f7dec695e79504def1ae13ceddee..1492fc629457cd5f7ca312b452ccd79ab30f175d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -196,15 +196,16 @@ struct StridesCalculation { } }; -template +template struct BroadcastArgsWarpper { - using VecType = CudaAlignedVector; + using InVecType = CudaAlignedVector; + using OutVecType = CudaAlignedVector; - T *out_data; - VecType *vec_out_data; - const T *__restrict__ in_data[ET]; - const VecType *__restrict__ vec_in_data[ET]; + OutT *out_data; + OutVecType *vec_out_data; + const InT *__restrict__ in_data[ET]; + const InVecType *__restrict__ vec_in_data[ET]; bool no_broadcast[ET]; FastDivMod divmoders[kDims]; uint32_t strides[ET][framework::DDim::kMaxRank]; @@ -217,14 +218,14 @@ struct BroadcastArgsWarpper { const StridesCalculation &offset_calculator) : scalar_cal_offset(scalar_cal_offset), func(func) { for (int j = 0; j < ET; ++j) { - in_data[j] = ins[j]->data(); - vec_in_data[j] = reinterpret_cast(in_data[j]); + in_data[j] = ins[j]->data(); + vec_in_data[j] = reinterpret_cast(in_data[j]); no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false; memcpy(strides[j], offset_calculator.strides[j].data(), kDims * sizeof(uint32_t)); } - out_data = out->data(); - vec_out_data = reinterpret_cast(out_data); + out_data = out->data(); + vec_out_data = reinterpret_cast(out_data); memcpy(divmoders, offset_calculator.divmoders.data(), kDims * sizeof(FastDivMod)); } @@ -241,12 +242,12 @@ struct BroadcastArgsWarpper { return offset; } - __device__ __forceinline__ void LoadVectorizedDataCommon(VecType *vector_args, - int tid, int idx) { + __device__ __forceinline__ void LoadVectorizedDataCommon( + InVecType *vector_args, int tid, int idx) { *vector_args = vec_in_data[idx][tid]; } - __device__ __forceinline__ void LoadVectorizedDataByDivmod(T *scalar_args, + __device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args, int tid, int idx) { int index = tid * VecSize; #pragma unroll(VecSize) @@ -256,23 +257,23 @@ struct BroadcastArgsWarpper { } } - __device__ __forceinline__ void LoadScalarizedDataCommon(T args[], int tid, + __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid, int idx) { args[idx] = in_data[idx][tid + scalar_cal_offset]; } - __device__ __forceinline__ void LoadScalarizedDataByDivmod(T args[], int tid, - int idx) { + __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[], + int tid, int idx) { auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx); args[idx] = in_data[idx][offset]; } - __device__ __forceinline__ void LoadVectorizedData(T (*args)[VecSize], + __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize], int tid) { #pragma unroll(ET) for (int j = 0; j < ET; ++j) { if (no_broadcast[j]) { - VecType *vector_args = reinterpret_cast(args[j]); + InVecType *vector_args = reinterpret_cast(args[j]); LoadVectorizedDataCommon(vector_args, tid, j); } else { LoadVectorizedDataByDivmod(args[j], tid, j); @@ -280,7 +281,7 @@ struct BroadcastArgsWarpper { } } - __device__ __forceinline__ void LoadScalarizedData(T args[], int tid) { + __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) { #pragma unroll(ET) for (int j = 0; j < ET; ++j) { if (no_broadcast[j]) { @@ -291,36 +292,39 @@ struct BroadcastArgsWarpper { } } - __device__ __forceinline__ void StoreVectorizedData(T (*args)[VecSize], + __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out, int tid) { - VecType *args_out = reinterpret_cast(args[0]); - vec_out_data[tid] = *args_out; + vec_out_data[tid] = vec_args_out; } - __device__ __forceinline__ void StoreScalarizedData(T args[], int tid) { - out_data[scalar_cal_offset + tid] = args[0]; + __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) { + out_data[scalar_cal_offset + tid] = args_out; } }; -template +template __device__ inline void ScalarizedBroadcastKernelImpl( BroadcastArgsWarpper broadcast_warpper, int tid) { - T args[ET]; + InT args[ET]; + OutT args_out; broadcast_warpper.LoadScalarizedData(args, tid); #pragma unroll(ET) for (int j = 1; j < ET; ++j) { - args[0] = broadcast_warpper.func(args); + args_out = broadcast_warpper.func(args); } - broadcast_warpper.StoreScalarizedData(args, tid); + broadcast_warpper.StoreScalarizedData(args_out, tid); } -template +template __device__ inline void VectorizedBroadcastKernelImpl( BroadcastArgsWarpper broadcast_warpper, int tid) { - T ins[ET]; - T args[ET][VecSize]; + using OutVecType = CudaAlignedVector; + OutVecType args_out; + InT ins[ET]; + InT args[ET][VecSize]; broadcast_warpper.LoadVectorizedData(args, tid); #pragma unroll(VecSize) @@ -329,13 +333,13 @@ __device__ inline void VectorizedBroadcastKernelImpl( for (int j = 0; j < ET; ++j) { ins[j] = args[j][i]; } - args[0][i] = broadcast_warpper.func(ins); + args_out.val[i] = broadcast_warpper.func(ins); } - broadcast_warpper.StoreVectorizedData(args, tid); + broadcast_warpper.StoreVectorizedData(args_out, tid); } -template +template __global__ void ElementwiseBroadcastKernel( BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) { int tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -345,19 +349,20 @@ __global__ void ElementwiseBroadcastKernel( // eg: Calcualting the front 1024-length data in total 1027 data once VecSize // is 4. if (tid < main_tid) { - VectorizedBroadcastKernelImpl( + VectorizedBroadcastKernelImpl( broadcast_warpper, tid); } // Scalarzed calculation of rest data whose lenght cannot fulfill VecSize. // eg: Calcualting the rest 3-length data in total 1027 data once VecSize is // 4. if (tid < tail_tid) { - ScalarizedBroadcastKernelImpl( + ScalarizedBroadcastKernelImpl( broadcast_warpper, tid); } } -template +template void LaunchBroadcastKernelForDifferentDimSize( const platform::CUDADeviceContext &ctx, const std::vector &ins, framework::Tensor *out, @@ -376,65 +381,73 @@ void LaunchBroadcastKernelForDifferentDimSize( switch (merge_dims.dim_size) { case 1: { - auto broadcast_warpper = BroadcastArgsWarpper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel( + ins, out, vec_len, func, offset_calculator); + ElementwiseBroadcastKernel<<>>( broadcast_warpper, main_tid, tail_tid); break; } case 2: { - auto broadcast_warpper = BroadcastArgsWarpper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel( + ins, out, vec_len, func, offset_calculator); + ElementwiseBroadcastKernel<<>>( broadcast_warpper, main_tid, tail_tid); break; } case 3: { - auto broadcast_warpper = BroadcastArgsWarpper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel( + ins, out, vec_len, func, offset_calculator); + ElementwiseBroadcastKernel<<>>( broadcast_warpper, main_tid, tail_tid); break; } case 4: { - auto broadcast_warpper = BroadcastArgsWarpper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel( + ins, out, vec_len, func, offset_calculator); + ElementwiseBroadcastKernel<<>>( broadcast_warpper, main_tid, tail_tid); break; } case 5: { - auto broadcast_warpper = BroadcastArgsWarpper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel( + ins, out, vec_len, func, offset_calculator); + ElementwiseBroadcastKernel<<>>( broadcast_warpper, main_tid, tail_tid); break; } case 6: { - auto broadcast_warpper = BroadcastArgsWarpper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel( + ins, out, vec_len, func, offset_calculator); + ElementwiseBroadcastKernel<<>>( broadcast_warpper, main_tid, tail_tid); break; } case 7: { - auto broadcast_warpper = BroadcastArgsWarpper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel( + ins, out, vec_len, func, offset_calculator); + ElementwiseBroadcastKernel<<>>( broadcast_warpper, main_tid, tail_tid); break; } case 8: { - auto broadcast_warpper = BroadcastArgsWarpper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel( + ins, out, vec_len, func, offset_calculator); + ElementwiseBroadcastKernel<<>>( broadcast_warpper, main_tid, tail_tid); break; @@ -448,7 +461,7 @@ void LaunchBroadcastKernelForDifferentDimSize( } } -template +template void LaunchBroadcastElementwiseCudaKernel( const platform::CUDADeviceContext &ctx, const std::vector &ins, @@ -457,27 +470,27 @@ void LaunchBroadcastElementwiseCudaKernel( int in_vec_size = 4; framework::Tensor *out = (*outs)[0]; for (auto *in : ins) { - auto temp_size = GetVectorizedSizeImpl(in->data()); + auto temp_size = GetVectorizedSizeImpl(in->data()); in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size) : in_vec_size; } - int out_vec_size = GetVectorizedSizeImpl(out->data()); + int out_vec_size = GetVectorizedSizeImpl(out->data()); int vec_size = std::min(out_vec_size, in_vec_size); switch (vec_size) { case 4: { - LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out, axis, - func); + LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out, + axis, func); break; } case 2: { - LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out, axis, - func); + LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out, + axis, func); break; } case 1: { - LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out, axis, - func); + LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out, + axis, func); break; } default: { @@ -502,8 +515,9 @@ void LaunchElementwiseCudaKernel( LaunchSameDimsElementwiseCudaKernel( cuda_ctx, ins, outs, func); } else { - LaunchBroadcastElementwiseCudaKernel( - cuda_ctx, ins, outs, axis, func); + LaunchBroadcastElementwiseCudaKernel(cuda_ctx, ins, outs, axis, + func); } }