diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h index fb046481e20c5032cd3073f5c8a6cdee1252d0be..bc3eedacc5b0d284ca873a46974adaa7ba4e46d9 100644 --- a/paddle/fluid/operators/dropout_impl.cu.h +++ b/paddle/fluid/operators/dropout_impl.cu.h @@ -112,17 +112,17 @@ __global__ void VectorizedRandomGenerator(const size_t n, auto dst_functor = DstMaskFunctor(1.0f - dropout_prob, is_upscale_in_train); for (; fix < main_offset; fix += stride) { - kps::ReadData(&dst_mask[0], src + fix, deal_size); - kps::ElementwiseRandom( + kps::ReadData(&dst_mask[0], src + fix, deal_size); + kps::ElementwiseRandom( &rands[0], Rand(), &state); // dst kps::OperatorTernary>( &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], deal_size); + kps::WriteData(dst + fix, &dst_mask[0], deal_size); // mask - kps::ElementwiseUnary( + kps::ElementwiseUnary( &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData( + kps::WriteData( mask + fix, &mask_result[0], deal_size); if (fix > idx * kCount + 1) { __syncthreads(); @@ -130,17 +130,17 @@ __global__ void VectorizedRandomGenerator(const size_t n, } int remainder = n - fix; if (remainder > 0) { - kps::ReadData(&dst_mask[0], src + fix, remainder); - kps::ElementwiseRandom( + kps::ReadData(&dst_mask[0], src + fix, remainder); + kps::ElementwiseRandom( &rands[0], Rand(), &state); // dst kps::OperatorTernary>( &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount); - kps::WriteData(dst + fix, &dst_mask[0], remainder); + kps::WriteData(dst + fix, &dst_mask[0], remainder); // mask - kps::ElementwiseUnary( + kps::ElementwiseUnary( &mask_result[0], &dst_mask[kCount], Cast()); - kps::WriteData( + kps::WriteData( mask + fix, &mask_result[0], remainder); __syncthreads(); } @@ -233,17 +233,17 @@ __global__ void VectorizedGeneratorMask(const size_t n, auto mask_functor = MaskFunctor(1.0f - dropout_prob); for (; fix < main_offset; fix += stride) { - kps::ReadData(&dst_mask[0], src + fix, deal_size); - kps::ElementwiseRandom( + kps::ReadData(&dst_mask[0], src + fix, deal_size); + kps::ElementwiseRandom( &rands[0], Rand(), &state); // dst kps::OperatorBinary>( &dst_mask[0], &rands[0], mask_functor, kCount); // mask - kps::ElementwiseUnary( + kps::ElementwiseUnary( &mask_result[0], &dst_mask[0], Cast()); - kps::WriteData( + kps::WriteData( mask + fix, &mask_result[0], deal_size); if (fix > idx * kCount + 1) { __syncthreads(); @@ -251,16 +251,16 @@ __global__ void VectorizedGeneratorMask(const size_t n, } int remainder = n - fix; if (remainder > 0) { - kps::ReadData(&dst_mask[0], src + fix, remainder); - kps::ElementwiseRandom( + kps::ReadData(&dst_mask[0], src + fix, remainder); + kps::ElementwiseRandom( &rands[0], Rand(), &state); // dst kps::OperatorBinary>( &dst_mask[0], &rands[0], mask_functor, kCount); // mask - kps::ElementwiseUnary( + kps::ElementwiseUnary( &mask_result[0], &dst_mask[0], Cast()); - kps::WriteData( + kps::WriteData( mask + fix, &mask_result[0], remainder); __syncthreads(); } diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h index fa50d5b23bfa2d43bd5676e2a129593c604354b4..2a0881ca0939a0375eb4f1546424aa36fd6811d1 100644 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -73,24 +73,23 @@ __global__ void BroadcastKernelBinary( // load in0 if (use_broadcast[0]) { - kernel_primitives::ReadDataBc( + kernel_primitives::ReadDataBc( arg0, in0, fix, 
configlists[0], numel); } else { kernel_primitives::ReadData(arg0, in0 + fix, num); } // load in1 if (use_broadcast[1]) { - kernel_primitives::ReadDataBc( + kernel_primitives::ReadDataBc( arg1, in1, fix, configlists[1], numel); } else { - kernel_primitives::ReadData(arg1, in1 + fix, num); + kernel_primitives::ReadData(arg1, in1 + fix, num); } // compute - kernel_primitives::ElementwiseBinary( + kernel_primitives::ElementwiseBinary( result, arg0, arg1, func); // store - kernel_primitives::WriteData( - out + fix, result, num); + kernel_primitives::WriteData(out + fix, result, num); } // bias add forward impl for "[m, n] + [n] = [m, n]" diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 74e48f39185485fac9d55e778645686955b6d606..9b9d9e1d20e12258870be612871d7a0e9abbdeea 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -266,10 +266,10 @@ __device__ __forceinline__ void LoadData( // numel : whole num of output // num: how many data will be deal with in this time if (need_broadcast) { - kps::ReadDataBc( + kps::ReadDataBc( dst, src, block_offset, config, numel, read_lens); } else { - kps::ReadData( + kps::ReadData( dst, src + block_offset, num, read_lens); } } diff --git a/paddle/phi/kernels/funcs/distribution_helper.h b/paddle/phi/kernels/funcs/distribution_helper.h index 0e6b3a3f9d733ea6a43f71daabf6d640bb85b424..abade7ac0ef877a809cba24f3bb740212d2aa2ad 100644 --- a/paddle/phi/kernels/funcs/distribution_helper.h +++ b/paddle/phi/kernels/funcs/distribution_helper.h @@ -278,11 +278,10 @@ __global__ void DistributionKernel(size_t size, MT args[kCount]; T result[kCount]; for (size_t i = idx; i < size; i += total_thread * kCount) { - kps::ElementwiseRandom( - &args[0], dist, &state); - kps::ElementwiseUnary( + kps::ElementwiseRandom(&args[0], dist, &state); + kps::ElementwiseUnary( &result[0], &args[0], trans); - kps::WriteData( + kps::WriteData( out_data + i, &result[0], size - i, 1, stride, 1); __syncthreads(); } diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index ddbbe4b1718f162f8aa3245d504c58bea65a9bec..2573a0e44c90ca3ae671ff8e8665b4cf74d2eeb2 100755 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -519,13 +519,13 @@ struct Loader { kps::Init( args, static_cast(1.0f), read_lens); if (is_boundary) { - kps::ReadData( + kps::ReadData( args, reinterpret_cast(in[Index]) + offset, num, read_lens); } else { - kps::ReadData( + kps::ReadData( args, reinterpret_cast(in[Index]) + offset, num, @@ -595,7 +595,7 @@ struct ElementwisePrimitiveCaller { InT (*args)[VecSize], OutT *result, int read_lens) { - kps::ElementwiseAny( + kps::ElementwiseAny( result, args, func); } }; @@ -606,7 +606,7 @@ struct ElementwisePrimitiveCaller { InT (*args)[VecSize], OutT *result, int read_lens) { - kps::ElementwiseConstant(result, func); + kps::ElementwiseConstant(result, func); } }; @@ -616,7 +616,7 @@ struct ElementwisePrimitiveCaller { InT (*args)[VecSize], OutT *result, int read_lens) { - kps::ElementwiseUnary( + kps::ElementwiseUnary( result, args[0], func); } }; @@ -627,7 +627,7 @@ struct ElementwisePrimitiveCaller { InT (*args)[VecSize], OutT *result, int read_lens) { - kps::ElementwiseBinary( + kps::ElementwiseBinary( result, args[0], args[1], func, read_lens); } }; @@ -638,7 +638,7 @@ struct ElementwisePrimitiveCaller { InT (*args)[VecSize], OutT *result, int read_lens) { - 
kps::ElementwiseTernary( + kps::ElementwiseTernary( result, args[0], args[1], args[2], func); } }; @@ -703,7 +703,7 @@ struct ElementwiseWriteDataCallerBc { } #pragma unroll for (int i = 0; i < NumOuts; ++i) { - kps::WriteData( + kps::WriteData( outs[i] + block_offset, dst[i], num, read_lens); } } @@ -716,7 +716,7 @@ struct ElementwiseWriteDataCallerBc { kps::IndexType block_offset, int num, int read_lens) { - kps::WriteData( + kps::WriteData( outs[0] + block_offset, src, num, read_lens); } }; diff --git a/paddle/phi/kernels/funcs/index_impl.cu.h b/paddle/phi/kernels/funcs/index_impl.cu.h index f90380bef70bd7ce1376c9b127d59ba5bb0f7ef4..4e2e2a7508700f0351ffcb716679971076d03d8c 100644 --- a/paddle/phi/kernels/funcs/index_impl.cu.h +++ b/paddle/phi/kernels/funcs/index_impl.cu.h @@ -36,18 +36,18 @@ __global__ void VectorizedIndexKernel(T *out, size_t args[VecSize]; T result[VecSize]; for (; data_offset < main_offset; data_offset += stride) { - kps::InitWithDataIndex(&args[0], data_offset); - kps::ElementwiseUnary( + kps::InitWithDataIndex(&args[0], data_offset); + kps::ElementwiseUnary( &result[0], &args[0], func); - kps::WriteData( + kps::WriteData( out + data_offset, &result[0], BLOCK_NUM_X * VecSize); } size_t num = numel - data_offset; if (num > 0) { - kps::InitWithDataIndex(&args[0], data_offset); - kps::ElementwiseUnary( + kps::InitWithDataIndex(&args[0], data_offset); + kps::ElementwiseUnary( &result[0], &args[0], func); - kps::WriteData(out + data_offset, &result[0], num); + kps::WriteData(out + data_offset, &result[0], num); } } diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index 4d903e01a49824d8f637afbd35c901bc001653cc..446dfc73d5bd692ba15eabc5ffc3a61d55a3809b 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -712,7 +712,6 @@ __global__ void ReduceAnyKernel(const Tx* x, 1, REDUCE_VEC_SIZE, 1, - 1, Calculator, kps::IdentityFunctor, false>(&input_reg[0], @@ -725,12 +724,11 @@ __global__ void ReduceAnyKernel(const Tx* x, stride, kps::IdentityFunctor(), reduce_last_dim); - kps::ElementwiseUnary( + kps::ElementwiseUnary( &input_compute[0], &input_reg[0], transformer); kps::Reduce( &reduce_var, &input_compute[0], reducer, reduce_last_dim); @@ -742,7 +740,6 @@ __global__ void ReduceAnyKernel(const Tx* x, 1, REDUCE_VEC_SIZE, 1, - 1, Calculator, TransformOp, true>(&input_compute[0], @@ -758,12 +755,11 @@ __global__ void ReduceAnyKernel(const Tx* x, kps::Reduce( &reduce_var, &input_compute[0], reducer, reduce_last_dim); - kps::Reduce( + kps::Reduce( &reduce_var, &reduce_var, reducer, reduce_last_dim); if (is_mean) { reduce_var = reduce_var / static_cast(reduce_num); @@ -807,27 +803,22 @@ __global__ void ReduceHigherDimKernel(const Tx* x, MPType reduce_var = init; MPType reduce_compute = init; for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) { - kps::ReadData(&reduce_input, - input + loop_idx * left_num + idx, - block.BlockDimX(), - 1, - 1, - left_num); - kps::ElementwiseUnary( + kps::ReadData(&reduce_input, + input + loop_idx * left_num + idx, + block.BlockDimX(), + 1, + 1, + left_num); + kps::ElementwiseUnary( &reduce_compute, &reduce_input, transformer); - kps::Reduce( + kps::Reduce( &reduce_var, &reduce_compute, reducer, false); } if (is_mean) { reduce_var = reduce_var / static_cast(mean_div); } Ty result = static_cast(reduce_var); - kps::WriteData( + kps::WriteData( y + store_offset + idx, &result, block.BlockDimX()); } @@ -835,20 +826,15 @@ __global__ void 
ReduceHigherDimKernel(const Tx* x, MPType reduce_var = init; MPType reduce_compute = init; for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) { - kps::ReadData(&reduce_input, - input + loop_idx * left_num + idx, - dim.rem_x, - 1, - 1, - left_num); - kps::ElementwiseUnary( + kps::ReadData(&reduce_input, + input + loop_idx * left_num + idx, + dim.rem_x, + 1, + 1, + left_num); + kps::ElementwiseUnary( &reduce_compute, &reduce_input, transformer); - kps::Reduce( + kps::Reduce( &reduce_var, &reduce_compute, reducer, false); } @@ -856,8 +842,7 @@ __global__ void ReduceHigherDimKernel(const Tx* x, reduce_var = reduce_var / static_cast(mean_div); } Ty result = static_cast(reduce_var); - kps::WriteData( - y + store_offset + idx, &result, dim.rem_x); + kps::WriteData(y + store_offset + idx, &result, dim.rem_x); } } diff --git a/paddle/phi/kernels/funcs/select_impl.cu.h b/paddle/phi/kernels/funcs/select_impl.cu.h index 831e0ca907b3cb728373fcd471f12a8a48977765..4fb1bc13ae7f82ad5d77d12981f76a773804735f 100644 --- a/paddle/phi/kernels/funcs/select_impl.cu.h +++ b/paddle/phi/kernels/funcs/select_impl.cu.h @@ -71,21 +71,21 @@ __device__ void GetBlockCountImpl(const InT *in, int store_fix = BLOCK_ID_X + repeat * GRID_NUM_X; kps::Init(&in_data[0], static_cast(0.0f)); - kps::ReadData(&in_data[0], in, num); - kps::ElementwiseUnary( + kps::ReadData(&in_data[0], in, num); + kps::ElementwiseUnary( &temp[0], &in_data[0], Cast()); - kps::Reduce( + kps::Reduce( &result, &temp[0], Add(), true); - kps::Reduce( + kps::Reduce( &result, &result, Add(), true); if (store_fix == 0) { // first block's fix_size = 0; OutT tmp = static_cast(0.0f); - kps::WriteData(out + store_fix, &tmp, 1); + kps::WriteData(out + store_fix, &tmp, 1); } // store num of this block - kps::WriteData(out + store_fix + 1, &result, 1); + kps::WriteData(out + store_fix + 1, &result, 1); } // Count how many data is not zero in current block @@ -132,12 +132,12 @@ __device__ void CumsumImpl( // set pre_cumsum kps::Init(&temp[0], *pre_cumsum); // load data to arg - kps::ReadData( + kps::ReadData( &arg[0], in, num, 1, BLOCK_NUM_X, 1); // block cumsum - kps::Cumsum(&result[0], &arg[0], func); + kps::Cumsum(&result[0], &arg[0], func); // result = cumsum_result + pre_cumsum - kps::ElementwiseBinary( + kps::ElementwiseBinary( &result[0], &result[0], &temp[0], func); // get the last prefix sum if ((THREAD_ID_X == BLOCK_NUM_X - 1) && !IsBoundary) { @@ -146,7 +146,7 @@ __device__ void CumsumImpl( __syncthreads(); // update pre_cumsum *pre_cumsum = max_thread_data; - kps::WriteData( + kps::WriteData( out, &result[0], num, 1, BLOCK_NUM_X, 1); } @@ -189,7 +189,7 @@ struct SelectCaller { int64_t in_data[VecSize]; OutT store_data[VecSize * phi::DDim::kMaxRank]; // set index - kps::InitWithDataIndex(&in_data[0], data_offset); + kps::InitWithDataIndex(&in_data[0], data_offset); // Get store data according to mask_idt kps::OperatorTernary( store_data, mask_data, &in_data[0], func, VecSize); @@ -215,7 +215,7 @@ struct SelectCaller { int num) { InT in_data[VecSize]; OutT store_data[VecSize * phi::DDim::kMaxRank]; - kps::ReadData(&in_data[0], in, num); + kps::ReadData(&in_data[0], in, num); // Get store data according to mask_idt kps::OperatorTernary( store_data, mask_data, &in_data[0], func, VecSize); @@ -244,7 +244,7 @@ struct SelectCaller { kps::details::ReadData(&in_data[0], in + thread_fix, store_num); kps::OperatorTernary( store_data, mask_data, &in_data[0], func, VecSize); - kps::WriteData(out, &store_data[0], num); + kps::WriteData(out, &store_data[0], num); } 
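The select_impl.cu.h hunks above feed kps::Cumsum inside CumsumImpl to turn per-thread counts into store offsets. For orientation, here is a minimal standalone CUDA sketch of such a block-wide inclusive scan over one value per thread; it only illustrates the pattern (a Hillis-Steele scan over shared memory) and is not the kps::Cumsum implementation. The fixed block size is an assumption.

// Block-wide inclusive prefix sum, one int per thread. Sketch only.
template <int kBlockSize>
__global__ void BlockInclusiveScanSketch(const int* in, int* out, int n) {
  __shared__ int buf[kBlockSize];
  int tid = threadIdx.x;
  int idx = blockIdx.x * kBlockSize + tid;
  buf[tid] = (idx < n) ? in[idx] : 0;
  __syncthreads();
  // Each step adds the value `stride` slots to the left; after
  // log2(kBlockSize) steps buf[tid] holds the inclusive prefix sum.
  for (int stride = 1; stride < kBlockSize; stride *= 2) {
    int val = (tid >= stride) ? buf[tid - stride] : 0;
    __syncthreads();
    buf[tid] += val;
    __syncthreads();
  }
  if (idx < n) out[idx] = buf[tid];
}

Launched as BlockInclusiveScanSketch<256><<<grid, 256>>>(in, out, n), each block scans its own tile, which is the per-block stage CumsumImpl performs before pre_cumsum carries the running total across repeats.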
}; @@ -285,16 +285,16 @@ __device__ void SelectKernelImpl(OutT *out, kps::Init(&num_thread[0], init_idx); kps::Init(&mask_data[0], init_mask); // Load mask - kps::ReadData(&mask_data[0], mask, num); + kps::ReadData(&mask_data[0], mask, num); // Cast from MT to int - kps::ElementwiseUnary( + kps::ElementwiseUnary( &mask_idt[0], &mask_data[0], Cast()); // Get the num of thread only num_thread[1] has data - kps::Reduce( + kps::Reduce( &num_thread[0], &mask_idt[0], Add(), true); // Get cumsum_thread cumsum from 0 to num_thread cumsum_thread[0] is the // thread_fix - kps::Cumsum(&cumsum_thread[0], &num_thread[0], Add()); + kps::Cumsum(&cumsum_thread[0], &num_thread[0], Add()); // get thread_fix int thread_fix = (static_cast(cumsum_thread[0] - num_thread[0]) * store_rank); diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index ef3406fd7f668396e97ea2055f5e88ad84501cd3..ffc6a2e3d6f3276a20f4877c1438050e9863f527 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -311,9 +311,9 @@ __global__ void WarpSoftmaxForward(T* softmax, const VecT* src_v = reinterpret_cast(&src[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&src_data[i][0][0]); - kps::ReadData( + kps::ReadData( ®_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1); - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &sub_data[i][0][0], &src_data[i][0][0], DataTransFunctor()); } @@ -321,7 +321,6 @@ __global__ void WarpSoftmaxForward(T* softmax, kps::Reduce, kMode::kLocalMode>( &max[0], &sub_data[0][0][0], ReduceMaxFunctor(), true); @@ -330,15 +329,14 @@ __global__ void WarpSoftmaxForward(T* softmax, // compute sum #pragma unroll for (int i = 0; i < kBatchSize; ++i) { - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor(max[i])); - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &exp_data[i][0][0], &sub_data[i][0][0], ExpFunctor()); } kps::Reduce, kMode::kLocalMode>( &sum[0], &exp_data[0][0][0], kps::AddFunctor(), true); @@ -351,15 +349,15 @@ __global__ void WarpSoftmaxForward(T* softmax, reinterpret_cast(&softmax[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); if (LogMode) { - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &out_tmp[i][0][0], &sub_data[i][0][0], UnarySubFunctor(std::log(sum[i]))); } else { - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &out_tmp[i][0][0], &exp_data[i][0][0], UnaryDivFunctor(sum[i])); } - kps::WriteData( + kps::WriteData( &softmax_v[0], ®_v[0], idx_max_v[i], 0, kWarpSize, 1); } } @@ -417,9 +415,9 @@ __global__ void WarpSoftmaxBackward(T* dst, int ptr = (first_batch + i) * stride; const VecT* src_v = reinterpret_cast(&src[ptr]); const VecT* grad_v = reinterpret_cast(&grad[ptr]); - kps::ReadData( + kps::ReadData( &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag); - kps::ReadData( + kps::ReadData( &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag); } @@ -430,9 +428,9 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad_ptr = reinterpret_cast(&grad_reg[0][0]); constexpr int kStep = kBatchSize * kLoopsV * kVSize; constexpr int kVItem = kLoopsV * kVSize; - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor()); - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor()); // compute sum @@ -444,17 +442,15 @@ __global__ void WarpSoftmaxBackward(T* dst, kps::Reduce, 
kps::details::ReduceMode::kLocalMode>( &sum[0], &grad_tmp[0][0][0], kps::AddFunctor(), true); } else { - kps::ElementwiseBinary>( + kps::ElementwiseBinary>( &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor()); kps::Reduce, kps::details::ReduceMode::kLocalMode>( &sum[0], &sum_tmp[0][0][0], kps::AddFunctor(), true); @@ -470,17 +466,17 @@ __global__ void WarpSoftmaxBackward(T* dst, AccT* gradptr = reinterpret_cast(&grad_tmp[i][0][0]); AccT* srcptr = reinterpret_cast(&src_tmp[i][0][0]); if (LogMode) { - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &out[i][0][0], &srcptr[0], ExpMulFunctor(sum[i])); - kps::ElementwiseBinary>( + kps::ElementwiseBinary>( &out_tmp[i][0][0], &gradptr[0], &out[i][0][0], kps::SubFunctor()); } else { - kps::ElementwiseUnary>( + kps::ElementwiseUnary>( &out[i][0][0], &gradptr[0], UnarySubFunctor(sum[i])); - kps::ElementwiseBinary>( + kps::ElementwiseBinary>( &out_tmp[i][0][0], &srcptr[0], &out[i][0][0], @@ -488,7 +484,7 @@ __global__ void WarpSoftmaxBackward(T* dst, } VecT* dst_v = reinterpret_cast(&dst[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); - kps::WriteData( + kps::WriteData( &dst_v[0], ®_v[0], idx_max_v[i], 0, kWarpSize, 1); } } @@ -636,7 +632,7 @@ __global__ void NormalSoftmaxForward( } if (blockDim.y > 1) { - kps::Reduce, kMode::kGlobalMode>( + kps::Reduce, kMode::kGlobalMode>( &max_value, &max_value, kps::MaxFunctor(), false); } @@ -647,7 +643,7 @@ __global__ void NormalSoftmaxForward( sum += std::exp(value - max_value); } if (blockDim.y > 1) { - kps::Reduce, kMode::kGlobalMode>( + kps::Reduce, kMode::kGlobalMode>( &sum, &sum, kps::AddFunctor(), false); } @@ -695,7 +691,7 @@ __global__ void NormalSoftmaxBackward(T* input_grad, } } if (blockDim.y > 1) { - kps::Reduce, kMode::kGlobalMode>( + kps::Reduce, kMode::kGlobalMode>( &sum, &sum, kps::AddFunctor(), false); } diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h index b5df98671f0b0eb692c0f129ebaf0255556d5ce8..2265077d51bb8bf5287929d4f1864e1961519a0e 100644 --- a/paddle/phi/kernels/primitive/compute_primitives.h +++ b/paddle/phi/kernels/primitive/compute_primitives.h @@ -200,7 +200,6 @@ __device__ inline int GetLastPow2(int n) { * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * OpFunc: Compute functor which has an operator() as following: * template @@ -215,12 +214,7 @@ __device__ inline int GetLastPow2(int n) { * in: The register pointer of in, the size is NX * NY. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in, OpFunc compute) { @@ -239,7 +233,6 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, * OutT: The data type of out. * NX: The number of data columns computed by each thread. * NY: The number of data rows computed by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * OpFunc: Compute functor which has an operator() as following: * template @@ -255,12 +248,7 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, * in2: The register pointer of second input, size is NX * NY. 
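Looking back at the softmax_gpudnn.h hunks above: WarpSoftmaxForward keeps a batch of rows in registers and reduces the row max and the exp-sum across the warp before normalizing. A hedged standalone sketch of that warp-level pattern, with one row of exactly 32 elements per warp and no vectorization (the real kernel batches kBatchSize rows and uses vector loads):

// One warp normalizes one row of warpSize elements. Sketch only,
// not the Paddle WarpSoftmaxForward kernel.
__global__ void WarpSoftmaxSketch(float* out, const float* in, int rows) {
  int warps_per_block = blockDim.x / 32;
  int row = blockIdx.x * warps_per_block + threadIdx.x / 32;
  int lane = threadIdx.x % 32;
  if (row >= rows) return;

  float v = in[row * 32 + lane];

  // Row max via butterfly shuffles (the kps::Reduce step with ReduceMaxFunctor).
  float m = v;
  for (int offset = 16; offset > 0; offset /= 2)
    m = fmaxf(m, __shfl_xor_sync(0xffffffffu, m, offset));

  // Row sum of exp(v - m) (the kps::Reduce step with AddFunctor).
  float e = __expf(v - m);
  float s = e;
  for (int offset = 16; offset > 0; offset /= 2)
    s += __shfl_xor_sync(0xffffffffu, s, offset);

  out[row * 32 + lane] = e / s;
}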
* compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1, const InT* in2, @@ -271,12 +259,7 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, } } -template +template __device__ __forceinline__ void ElementwiseBinary( OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) { #pragma unroll @@ -294,7 +277,6 @@ __device__ __forceinline__ void ElementwiseBinary( * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * OpFunc: Compute functor which has an operator() as following * template @@ -312,12 +294,7 @@ __device__ __forceinline__ void ElementwiseBinary( * in3: The register pointer of third input, size is NX * NY. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseTernary( OutT* out, const InT* in1, const InT* in2, const InT* in3, OpFunc compute) { #pragma unroll @@ -335,7 +312,6 @@ __device__ __forceinline__ void ElementwiseTernary( * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * Arity: The size of ins. * OpFunc: Compute functor which has an operator() as following: @@ -351,13 +327,7 @@ __device__ __forceinline__ void ElementwiseTernary( * ins: A pointers of array consisting of multiple inputs. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY], OpFunc compute) { @@ -382,7 +352,6 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * OpFunc: Compute functor which has an operator() as following * template @@ -398,12 +367,7 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, * in2: The register pointer of second input, size is NX * NY. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void CycleBinary(OutT* out, const InT* in1, const InT* in2, @@ -428,7 +392,6 @@ __device__ __forceinline__ void CycleBinary(OutT* out, * T: The type of data. * NX: The number of data continuously loaded by each thread. * NY: The number of data rows loaded by each thread, only NY = 1 was supported. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * ReduceFunctor: Compute functor which has an operator() as following * template @@ -448,7 +411,6 @@ __device__ __forceinline__ void CycleBinary(OutT* out, template __device__ __forceinline__ void Reduce(T* out, @@ -494,7 +456,6 @@ __device__ __forceinline__ void Reduce(T* out, * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. 
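The compute_primitives.h hunks in this stretch all make the same change: the BlockSize template parameter, described by the deleted doc lines as "Identifies the current device thread index method", is dropped from the compute primitives, and every call site earlier in the patch loses one template argument accordingly. The following is a hedged sketch of the shape of that change rather than the literal Paddle signatures:

// Before the patch: an extra BlockSize parameter that the GPU path never
// needed, because threadIdx.x already identifies the thread.
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
                                                 OpFunc compute);

// After the patch: the parameter is gone, so a call such as
//   kps::ElementwiseUnary<T, T, VecSize, 1, 1, Functor>(dst, src, func);
// shrinks to
//   kps::ElementwiseUnary<T, T, VecSize, 1, Functor>(dst, src, func);
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
                                                 OpFunc compute);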
- * BlockSize: Identifies the current device thread index method. Currently only * GPU was supported. * OpFunc: Compute functor which has an operator() as following * template @@ -509,12 +470,7 @@ __device__ __forceinline__ void Reduce(T* out, * out: The register pointer of out, the size is NX * NY. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { #pragma unroll for (int idx = 0; idx < NX * NY; idx++) { @@ -532,7 +488,6 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { * hiprandStatePhilox4_32_10_t. * OutT: the type of out register. * ReturnsCount: The number of random data generated by OpFunc. - * BlockSize: Identifies the current device thread index method. Currently only * GPU was supported. * OpFunc: Compute functor which has an operator() as following * template @@ -549,11 +504,7 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseRandom(OutT* out, OpFunc compute, StateType* state) { @@ -571,7 +522,6 @@ __device__ __forceinline__ void ElementwiseRandom(OutT* out, * @template paraments * InT: the type of input register. * OutT: the type of out register. - * BlockSize: Identifies the current device thread index method. Currently only * GPU was supported. * OpFunc: Compute functor which has an operator() as following * template @@ -589,7 +539,7 @@ __device__ __forceinline__ void ElementwiseRandom(OutT* out, */ #define SHARED_SIZE_LIMIT 512 -template +template __device__ __forceinline__ void Cumsum(OutT* out, const InT* in, OpFunc compute) { @@ -632,7 +582,6 @@ __device__ __forceinline__ void Cumsum(OutT* out, * @template paraments * InT: the type of input register. * OutT: the type of out register. - * BlockSize: Identifies the current device thread index method. Currently only * GPU was supported. * * @param @@ -645,7 +594,7 @@ __device__ __forceinline__ void Cumsum(OutT* out, #define SHARED_SIZE_LIMIT 1024 // each thread load 2 data from global memory so SHARED_SIZE_LIMIT must // larger than blockDim.x * 2 -template +template __device__ __forceinline__ void Sort(OutT* out, const InT* in, int num, @@ -689,7 +638,6 @@ __device__ __forceinline__ void Sort(OutT* out, * InT: The type of input register. * OutT: The type of out register. * IndexType: The type of index. - * BlockSize: Identifies the current device thread index method. Currently only * GPU was supported. * * @param @@ -701,7 +649,7 @@ __device__ __forceinline__ void Sort(OutT* out, * monotonic_type: if monotonic_type = 1 then sorted in ascending order, eles * sorted in escending. */ -template +template __device__ __forceinline__ void Sort(OutT* out, IndexType* out_index, const InT* in, diff --git a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h index 38a8d40aee628798e26ab3d7132d783deb6471a7..2fecebaf3d2687243d759037d6381f694f0be55c 100644 --- a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h @@ -89,7 +89,6 @@ __device__ void BlockXReduce(T* out, const T* data, OpFunc reducer) { * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. 
For xpu, * core_id() is used as the index. * OpFunc: Compute functor which has an operator() as following: * template @@ -104,12 +103,7 @@ __device__ void BlockXReduce(T* out, const T* data, OpFunc reducer) { * in: The register pointer of in, the size is NX * NY. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in, OpFunc compute) { @@ -128,7 +122,6 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, * OutT: The data type of out. * NX: The number of data columns computed by each thread. * NY: The number of data rows computed by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * OpFunc: Compute functor which has an operator() as following: * template @@ -144,12 +137,7 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, * in2: The register pointer of second input, size is NX * NY. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1, const InT* in2, @@ -160,12 +148,7 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, } } -template +template __device__ __forceinline__ void ElementwiseBinary( OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) { for (int idx = 0; idx < read_lens; ++idx) { @@ -182,7 +165,6 @@ __device__ __forceinline__ void ElementwiseBinary( * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * OpFunc: Compute functor which has an operator() as following * template @@ -200,12 +182,7 @@ __device__ __forceinline__ void ElementwiseBinary( * in3: The register pointer of third input, size is NX * NY. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseTernary( OutT* out, const InT* in1, const InT* in2, const InT* in3, OpFunc compute) { #pragma unroll @@ -223,7 +200,6 @@ __device__ __forceinline__ void ElementwiseTernary( * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * Arity: The size of ins * OpFunc: Compute functor which has an operator() as following: @@ -239,13 +215,7 @@ __device__ __forceinline__ void ElementwiseTernary( * ins: A pointers of array consisting of multiple inputs. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY], OpFunc compute) { @@ -270,7 +240,6 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * OpFunc: Compute functor which has an operator() as following * template @@ -286,12 +255,7 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, * in2: The register pointer of second input, size is NX * NY. * compute: Compute function which was declared like OpFunc(). 
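The compute_primitives_xpu2.h hunks mirror the GPU header change for change, including for the Reduce primitive trimmed above. As background for those Reduce call sites, the two modes correspond to a two-stage pattern: each thread first folds the NX values it holds in registers (kLocalMode), then the block combines the per-thread partials (kGlobalMode). A standalone plain-CUDA sketch of that pattern for a sum, assuming a power-of-two block size; it is an illustration, not the kps::Reduce implementation:

// Stage 1 (kLocalMode): fold NX register values per thread.
// Stage 2 (kGlobalMode): tree-reduce the partials through shared memory.
template <int kBlockSize, int NX>
__global__ void BlockSumSketch(const float* in, float* out, int n) {
  __shared__ float partial[kBlockSize];
  int base = (blockIdx.x * kBlockSize + threadIdx.x) * NX;

  float local = 0.0f;
#pragma unroll
  for (int i = 0; i < NX; ++i) {
    if (base + i < n) local += in[base + i];
  }

  partial[threadIdx.x] = local;
  __syncthreads();
  for (int stride = kBlockSize / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) {
      partial[threadIdx.x] += partial[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) out[blockIdx.x] = partial[0];
}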
*/ -template +template __device__ __forceinline__ void CycleBinary(OutT* out, const InT* in1, const InT* in2, @@ -316,7 +280,6 @@ __device__ __forceinline__ void CycleBinary(OutT* out, * T: The type of data. * NX: The number of data continuously loaded by each thread. * NY: The number of data rows loaded by each thread, only NY = 1 was supported. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * ReduceFunctor: Compute functor which has an operator() as following * template @@ -336,7 +299,6 @@ __device__ __forceinline__ void CycleBinary(OutT* out, template __device__ __forceinline__ void Reduce(T* out, @@ -369,7 +331,6 @@ __device__ __forceinline__ void Reduce(T* out, * OutT: The data type of out. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * OpFunc: Compute functor which has an operator() as following * template @@ -384,12 +345,7 @@ __device__ __forceinline__ void Reduce(T* out, * out: The register pointer of out, the size is NX * NY. * compute: Compute function which was declared like OpFunc(). */ -template +template __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { #pragma unroll for (int idx = 0; idx < NX * NY; idx++) { diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h index bf60d1610e322ed006e0d1d2ec317b0239a8b9b4..3f6148c7efd8000fb5d014a7e3433b2b756b2a96 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives.h +++ b/paddle/phi/kernels/primitive/datamover_primitives.h @@ -144,7 +144,6 @@ __device__ __forceinline__ void ReadData(T* dst, * Ty: The type of data that needs to be stored in registers. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * IsBoundary: Indicates whether to perform block access storage out-of-bounds * judgment. When the number of data processed by the block is less than @@ -161,12 +160,7 @@ __device__ __forceinline__ void ReadData(T* dst, * stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_ny: Each read one element stride stride_ny elements in the first dim. */ -template +template __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, int size_nx, @@ -275,7 +269,6 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) { * T: The type of data. * NX: Each thread load NX data from global memory continuously. * NY: Each thread need to load NY rows, only NY = 1 was supported. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * IsBoundary: Whether to make an out-of-bounds judgment on access to memory. * When the number of data processed by this block is less than @@ -287,7 +280,7 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) { * src: The data pointer of the current block. * size: The current block needs to load size data continuously. 
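The ReadData documentation above spells out the contract the rewritten call sites rely on: when a block owns a full tile it may load NX elements per thread with vector accesses, and only the tail tile (IsBoundary = true) pays for per-element bounds checks. A minimal plain-CUDA device helper in that spirit; float data and a vector width of 4 are assumptions, and this is not the kps::ReadData implementation:

// Load 4 floats per thread into registers; guard only on the tail tile.
__device__ __forceinline__ void LoadTileSketch(float dst[4], const float* src,
                                               int num, bool is_boundary) {
  int thread_offset = threadIdx.x * 4;
  if (is_boundary) {
    // Partial tile: element-wise loads with an out-of-bounds guard.
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      dst[i] = (thread_offset + i < num) ? src[thread_offset + i] : 0.0f;
    }
  } else {
    // Full tile: one 128-bit load per thread (assumes 16-byte alignment).
    *reinterpret_cast<float4*>(dst) =
        *reinterpret_cast<const float4*>(src + thread_offset);
  }
}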
*/ -template +template __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src, int num) { @@ -319,7 +312,7 @@ __device__ __forceinline__ void ReadData(T* dst, } } -template +template __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src, int num, @@ -361,7 +354,6 @@ __device__ __forceinline__ void ReadData(T* dst, * NY: Each thread need to load NY rows, only NY = 1 was supported. * ArgsT: The Type if dst, ArgsT can be std::tuple or std::tuple * Index: The index of data stored in dst. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * IsBoundary: Whether to make an out-of-bounds judgment on access to memory. * When the number of data processed by this block is less than @@ -376,7 +368,6 @@ __device__ __forceinline__ void ReadData(T* dst, template @@ -419,7 +410,6 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, * T: The type of data stored in the global memory. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * IsBoundary: Indicates whether to perform block access storage out-of-bounds @@ -437,7 +427,7 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, * stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_ny: Each read one element stride stride_ny elements in the first dim. */ -template +template __device__ __forceinline__ void ReadDataBc( T* dst, const T* __restrict__ src, @@ -479,7 +469,6 @@ __device__ __forceinline__ void ReadDataBc( * T: The type of data. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * IsBoundary: Indicates whether to perform block access storage out-of-bounds @@ -507,7 +496,6 @@ template +template __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src, int num) { @@ -613,7 +600,7 @@ __device__ __forceinline__ void WriteData(T* dst, } } -template +template __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src, int num, @@ -652,7 +639,6 @@ __device__ __forceinline__ void WriteData(T* dst, * Ty: The type of data that stored in the global memory. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * IsBoundary: Indicates whether to perform block access storage out-of-bounds * judgment. When the number of data processed by the block is less than @@ -669,12 +655,7 @@ __device__ __forceinline__ void WriteData(T* dst, * stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_ny: Each read one element stride stride_ny elements in the first dim. 
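ReadDataBc, which several hunks above simplify in the same way, differs from the plain ReadData only in how the source offset is derived: the output's linear index is decomposed dimension by dimension (the real code uses the precomputed fast-divmod tables in BroadcastConfig) and re-multiplied by the input strides, where a broadcast dimension contributes a stride of 0. A hedged sketch of that index mapping with ordinary division in place of the divmod tables:

// Map a linear output index to the corresponding input index under
// broadcasting. Dims are ordered outermost-to-innermost (row-major);
// in_strides[d] must be 0 for any dimension being broadcast.
__device__ __forceinline__ int64_t BcastSrcIndexSketch(
    int64_t out_idx, const int64_t* out_dims, const int64_t* in_strides,
    int rank) {
  int64_t src_idx = 0;
  for (int d = rank - 1; d >= 0; --d) {  // peel off the innermost coord first
    int64_t coord = out_idx % out_dims[d];
    out_idx /= out_dims[d];
    src_idx += coord * in_strides[d];
  }
  return src_idx;
}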
*/ -template +template __device__ __forceinline__ void WriteData(Ty* dst, const Tx* __restrict__ src, int size_nx, @@ -766,7 +747,6 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) { * T: The type of data stored in the global memory. * NX: The number of data continuously loaded by each thread. * NY: The number of data rows loaded by each thread, only NY = 1 was supported. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * IsBoundary: Indicates whether to perform block access storage out-of-bounds @@ -782,7 +762,7 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) { * coordinate mapping relationship between output data and input data. * total_num_output: Total number of original output. */ -template +template __device__ __forceinline__ void ReadDataBc( T* dst, const T* __restrict__ src, @@ -820,14 +800,13 @@ __device__ __forceinline__ void ReadDataBc( * T: Data type of register. * NX: Number of data to initialize. * NY: Number of data to initialize, NY only can be 1. - * BlockSize: Identifies the current device thread index method. For GPU, * threadIdx.x is used as the thread index. Currently only GPU was supported. * * @param: * dst: The register pointer of the thread, the size is NX. * init_data: The register pointer of init data, the size is NX. */ -template +template __device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) { int thread_offset = block_offset + threadIdx.x * NX; #pragma unroll diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h index 2915463f5fc916c1fbaf06389e3685da8408daf3..14a66516f5b632c2b94343645e04d7ae5165b983 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h @@ -337,7 +337,6 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst, * Ty: The type of data that needs to be stored in registers. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * IsBoundary: Indicates whether to perform block access storage out-of-bounds * judgment. When the number of data processed by the block is less than @@ -354,12 +353,7 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst, * stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_ny: Each read one element stride stride_ny elements in the first dim. */ -template +template __device__ __inline__ void ReadData(Ty* dst, const Tx _global_ptr_* src, int size_nx, @@ -472,7 +466,6 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) { * T: The type of data. * NX: Each thread load NX data from global memory continuously. * NY: Each thread need to load NY rows, only NY = 1 was supported. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * IsBoundary: Whether to make an out-of-bounds judgment on access to memory. * When the number of data processed by this block is less than @@ -484,7 +477,7 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) { * src: The data pointer of the current block. 
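InitWithDataIndex, whose GPU variant is touched just above, does nothing more than fill a thread's NX registers with the global element indices it owns; index_impl.cu.h earlier in this patch then applies a functor to those indices. A plain-CUDA sketch of the same helper (not the kps signature):

// Fill NX registers with the global data indices owned by this thread.
template <int NX>
__device__ __forceinline__ void FillDataIndexSketch(size_t* dst,
                                                    size_t block_offset) {
  size_t thread_offset = block_offset + threadIdx.x * NX;
#pragma unroll
  for (int i = 0; i < NX; ++i) {
    dst[i] = thread_offset + i;
  }
}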
* size: The current block needs to load size data continuously. */ -template +template __device__ __inline__ void ReadData(T* dst, const T _global_ptr_* src, int num) { @@ -502,7 +495,7 @@ __device__ __inline__ void ReadData(T* dst, } } -template +template __device__ __inline__ void ReadData(T* dst, const T _global_ptr_* src, int num, @@ -531,7 +524,6 @@ __device__ __inline__ void ReadData(T* dst, * NY: Each thread need to load NY rows, only NY = 1 was supported. * ArgsT: The Type if dst, ArgsT can be std::tuple or std::tuple * Index: The index of data stored in dst. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * IsBoundary: Whether to make an out-of-bounds judgment on access to memory. * When the number of data processed by this block is less than @@ -546,7 +538,6 @@ __device__ __inline__ void ReadData(T* dst, template @@ -582,7 +573,6 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, * T: The type of data stored in the global memory. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * IsBoundary: Indicates whether to perform block access storage out-of-bounds * judgment. When the number of data processed by the block is less than @@ -599,7 +589,7 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, * stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_ny: Each read one element stride stride_ny elements in the first dim. */ -template +template __device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, uint32_t block_offset, @@ -634,7 +624,6 @@ __device__ __inline__ void ReadDataBc(T* dst, * T: The type of data. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * IsBoundary: Indicates whether to perform block access storage out-of-bounds @@ -662,7 +651,6 @@ template +template __device__ void WriteData(T _global_ptr_* dst, const T* src, int num, @@ -766,7 +753,7 @@ __device__ void WriteData(T _global_ptr_* dst, } } -template +template __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { int thread_offset = core_id() * NX; mfence_local(); @@ -793,7 +780,6 @@ __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { * Ty: The type of data stored in the global memory. * NX: The number of data columns loaded by each thread. * NY: The number of data rows loaded by each thread. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * IsBoundary: Indicates whether to perform block access storage out-of-bounds * judgment. When the number of data processed by the block is less than @@ -810,12 +796,7 @@ __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { * stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_ny: Each read one element stride stride_ny elements in the first dim. */ -template +template __device__ __inline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, int size_nx, @@ -1190,7 +1171,6 @@ __device__ __inline__ void ReadDataBcCanNotCmp( * T: The type of data stored in the global memory. 
* NX: The number of data continuously loaded by each thread. * NY: The number of data rows loaded by each thread, only NY = 1 was supported. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * IsBoundary: Indicates whether to perform block access storage out-of-bounds * judgment. When the number of data processed by the block is less than @@ -1206,7 +1186,7 @@ __device__ __inline__ void ReadDataBcCanNotCmp( * read_lens: The number of data continuously loaded by each thread. * total_num_output: Total number of original output. */ -template +template __device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, uint32_t block_offset, @@ -1238,14 +1218,13 @@ __device__ __inline__ void ReadDataBc(T* dst, * T: Data type of register. * NX: Number of data to initialize. * NY: Number of data to initialize, NY only can be 1. - * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. * * @param: * dst: The register pointer of the thread, the size is NX. * init_data: The register pointer of init data, the size is NX. */ -template +template __device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) { int thread_offset = block_offset + core_id() * NX; #pragma unroll
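Taken together, most of the rewritten call sites in this patch share one skeleton, clearest in index_impl.cu.h: process full tiles of blockDim.x * VecSize elements in a grid-stride loop with the unguarded fast paths, then let the boundary-checked paths handle the remainder. A self-contained plain-CUDA sketch of that skeleton without the kps primitives; the functor type and launch shape are assumptions:

#include <cuda_runtime.h>

// Grid-stride skeleton: full tiles take the unguarded path, the final
// partial tile is bounds-checked. Illustrative only.
template <typename T, typename Functor, int VecSize>
__global__ void VectorizedApplySketch(T* out, size_t numel, Functor func) {
  size_t tile = static_cast<size_t>(blockDim.x) * VecSize;
  size_t stride = tile * gridDim.x;
  size_t main_offset = (numel / tile) * tile;
  size_t offset = static_cast<size_t>(blockIdx.x) * tile;

  for (; offset < main_offset; offset += stride) {
    size_t base = offset + threadIdx.x * VecSize;
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      out[base + i] = func(base + i);  // full tile: no bounds check
    }
  }
  if (offset < numel) {  // at most one block lands on the partial tile
    size_t base = offset + threadIdx.x * VecSize;
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      if (base + i < numel) out[base + i] = func(base + i);
    }
  }
}

A functor such as struct Iota { __device__ float operator()(size_t i) const { return static_cast<float>(i); } }; launched as VectorizedApplySketch<float, Iota, 4><<<grid, 256>>>(out, n, Iota{}); fills out with 0, 1, 2, ..., which is the kind of job VectorizedIndexKernel is used for.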