From d611e48c90d1a9145f97956ca2e5faea7a4a16bd Mon Sep 17 00:00:00 2001
From: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>
Date: Fri, 28 Apr 2023 18:45:44 +0800
Subject: [PATCH] Dropout optimize & clean broadcast inT and ElementwiseType
 (#52969)

* change judgement for DropoutGradGPUKernelDriver
* add UnrollerWithoutVecSize; after this, LoadData is to be refined
* pass unittest
* use same unroller with XPU
* BroadcastWithInt64Index
* BroadcastDataLoader template partial specialization
* fix compile errors in ROCm
* clean ElementwiseT and InT for BroadcastKernel
* default axis and clean inT
* remove redundant fast divmod computation
* optimize drop_nd & drop_nd_grad
* optimize BroadcastDataLoader bf16 fp16
* remove InT etc. after merging develop
* delete constexpr for windows ci
* fix conflict
* fix conflict with develop
* fix conflict
* new clean
* clean
---
 .../elementwise/elementwise_op_broadcast.cu.h |  14 +--
 .../elementwise/elementwise_op_function.h     |  11 +-
 .../elementwise/elementwise_op_impl.cu.h      |   2 -
 paddle/fluid/operators/fused/attn_gemm.h      |   4 +-
 paddle/fluid/operators/fused/attn_gemm_int8.h |   8 +-
 paddle/fluid/operators/fused/fmha_ref.h       |  22 ++--
 .../operators/fused/fused_gate_attention.h    |   8 +-
 .../fluid/operators/fused_token_prune_op.cu   |   3 +-
 paddle/fluid/operators/reduce_ops/reduce_op.h |  11 +-
 paddle/phi/kernels/cpu/bitwise_kernel.cc      |  18 ++--
 paddle/phi/kernels/cpu/compare_kernel.cc      |   6 +-
 paddle/phi/kernels/cpu/dirichlet_kernel.cc    |   4 +-
 .../kernels/cpu/elementwise_divide_kernel.cc  |   4 +-
 paddle/phi/kernels/cpu/elementwise_kernel.cc  |   2 +-
 .../phi/kernels/cpu/layer_norm_grad_kernel.cc |  80 +++++---------
 paddle/phi/kernels/cpu/layer_norm_kernel.cc   |  22 ++--
 paddle/phi/kernels/cpu/logical_kernel.cc      |   2 +-
 .../phi/kernels/cpu/matrix_rank_tol_kernel.cc |  11 +-
 paddle/phi/kernels/funcs/broadcast_function.h | 101 +++++++++---------
 paddle/phi/kernels/funcs/dropout_impl.cu.h    |  93 ++++------------
 paddle/phi/kernels/funcs/elementwise_base.h   |   5 +-
 paddle/phi/kernels/fusion/gpu/attn_gemm.h     |   4 +-
 paddle/phi/kernels/fusion/gpu/fmha_ref.h      |  22 ++--
 paddle/phi/kernels/gpu/dirichlet_kernel.cu    |   4 +-
 .../gpu/elementwise_divide_grad_kernel.cu     |  21 ++--
 paddle/phi/kernels/gpu/elementwise_grad.h     |  51 +++++----
 .../kernels/gpu/elementwise_grad_kernel.cu    |  42 ++++----
 paddle/phi/kernels/gpu/expand_as_kernel.cu    |   3 +-
 paddle/phi/kernels/gpu/expand_kernel.cu       |   3 +-
 .../phi/kernels/gpu/matrix_rank_tol_kernel.cu |   5 +-
 .../phi/kernels/gpu/reduce_amin_amax_common.h |  12 +--
 paddle/phi/kernels/gpu/reduce_grad.h          |  16 ++-
 .../phi/kernels/gpu/reduce_max_grad_kernel.cu |   8 +-
 .../kernels/gpu/reduce_mean_grad_kernel.cu    |   4 +-
 .../phi/kernels/gpu/reduce_min_grad_kernel.cu |   8 +-
 .../phi/kernels/gpu/reduce_sum_grad_kernel.cu |   2 +-
 .../gpu/squared_l2_norm_grad_kernel.cu        |   3 +-
 paddle/phi/kernels/gpu/tile_kernel.cu         |   8 +-
 .../phi/kernels/gpu/viterbi_decode_kernel.cu  |   3 +-
 paddle/phi/kernels/impl/complex_kernel_impl.h |   6 +-
 .../impl/elementwise_grad_kernel_impl.h       |   8 +-
 .../kernels/impl/elementwise_kernel_impl.h    |  16 +--
 paddle/phi/kernels/impl/lu_kernel_impl.h      |  12 +--
 .../phi/kernels/impl/set_value_kernel_impl.h  |   2 -
 paddle/phi/kernels/kps/bitwise_kernel.cu      |  23 ++--
 paddle/phi/kernels/kps/compare_kernel.cu      |   3 +-
 paddle/phi/kernels/kps/elementwise_kernel.cu  |   4 +-
 paddle/phi/kernels/kps/logical_kernel.cu      |  31 +++---
 .../kernels/legacy/cpu/elementwise_kernel.cc  |  16 +--
 .../legacy/kps/elementwise_raw_kernel.cu      |  21 ++--
 .../cpp/phi/kernels/test_ternary_broadcast.cu |   3 +-
 51 files changed, 334
insertions(+), 461 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 7ca62b1969e..94ce32cdaf1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -19,17 +19,13 @@ namespace paddle { namespace operators { -template +template void LaunchElementwiseCudaKernel( const KPDevice &ctx, const std::vector &ins, std::vector *outs, - int axis, - Functor func) { + Functor func, + int axis = -1) { std::vector pt_inputs; std::vector pt_outputs; // TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary @@ -53,8 +49,8 @@ void LaunchElementwiseCudaKernel( for (int i = 0; i < pt_outputs_tmp.size(); i++) { pt_outputs.push_back(pt_outputs_tmp[i].get()); } - phi::funcs::BroadcastKernel( - ctx, pt_inputs, &pt_outputs, axis, func); + phi::funcs::BroadcastKernel( + ctx, pt_inputs, &pt_outputs, func, axis); } } // namespace operators diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index b1d5f13bf85..c69acb89750 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -188,7 +188,7 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx, z->mutable_data(ctx.GetPlace()); const auto &dev_ctx = ctx.template device_context(); phi::funcs::ElementwiseCompute( - dev_ctx, *x, *y, axis, func, z); + dev_ctx, *x, *y, func, z, axis); } // FusedElemwiseAndAct @@ -1596,7 +1596,7 @@ static inline std::vector GetReduceDim(const framework::DDim &in, #if defined(__NVCC__) || defined(__HIPCC__) -template +template void GetGradXAndYOut(const phi::GPUContext &dev_ctx, const platform::Place &place, int axis, @@ -1605,11 +1605,11 @@ void GetGradXAndYOut(const phi::GPUContext &dev_ctx, phi::DenseTensor *dx, phi::DenseTensor *dy, Functor func) { - phi::GetGradXAndYOut( + phi::GetGradXAndYOut( dev_ctx, place, axis, ins, *dout, dx, dy, func); } -template +template void GetGradXOrYOut(const phi::GPUContext &dev_ctx, const platform::Place &place, int axis, @@ -1617,8 +1617,7 @@ void GetGradXOrYOut(const phi::GPUContext &dev_ctx, const phi::DenseTensor *dout, phi::DenseTensor *dxy, Functor func) { - phi::GetGradXOrYOut( - dev_ctx, place, axis, ins, *dout, dxy, func); + phi::GetGradXOrYOut(dev_ctx, place, axis, ins, *dout, dxy, func); } #endif diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 2680c8ecc5e..47e31783140 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -23,8 +23,6 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -using ElementwiseType = phi::ElementwiseType; - template void LaunchSameDimsElementwiseCudaKernel( const KPDevice &ctx, diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h index 9709f60bbc1..277e29c4d59 100644 --- a/paddle/fluid/operators/fused/attn_gemm.h +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -109,8 +109,8 @@ class AttnMatMul { // bias_out = output + bias std::vector ins = {output, bias}; std::vector outs = {bias_out}; - phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, phi::funcs::AddFunctor()); } } diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h index c61a7f60d43..705cb8ece41 100644 --- a/paddle/fluid/operators/fused/attn_gemm_int8.h +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -85,8 +85,8 @@ class AttnMatmulINT8 { // bias_out = output + bias std::vector ins = {output, bias}; std::vector outs = {bias_out}; - phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, phi::funcs::AddFunctor()); PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess, platform::errors::Fatal( @@ -139,8 +139,8 @@ class AttnMatmulINT8 { // bias_out = output + bias std::vector ins = {output, bias}; std::vector outs = {bias_out}; - phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, phi::funcs::AddFunctor()); PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess, platform::errors::Fatal( diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 1d83c7a62b1..843b5009a6f 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -255,12 +255,11 @@ class FMHARef { ins.emplace_back(src_mask_tensor); outs.emplace_back(src_mask_out_tensor); int elewise_add_axis = -1; - phi::funcs::BroadcastKernel( - dev_ctx_, - ins, - &outs, - elewise_add_axis, - phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel(dev_ctx_, + ins, + &outs, + phi::funcs::AddFunctor(), + elewise_add_axis); phi::SoftmaxForwardCUDAKernelDriver( dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); @@ -432,12 +431,11 @@ class FMHARef { ins.emplace_back(src_mask_tensor); outs.emplace_back(src_mask_out_tensor); int elewise_add_axis = -1; - phi::funcs::BroadcastKernel( - dev_ctx_, - ins, - &outs, - elewise_add_axis, - phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel(dev_ctx_, + ins, + &outs, + phi::funcs::AddFunctor(), + elewise_add_axis); phi::SoftmaxForwardCUDAKernelDriver( dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h index c8c4733df2e..105647baf1c 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention.h +++ b/paddle/fluid/operators/fused/fused_gate_attention.h @@ -689,13 +689,13 @@ class FMHAGateRef { std::vector ins = { qk_out, src_mask, nonbatched_bias}; std::vector outs = {qk_out}; - phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, TernaryAddFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, TernaryAddFunctor()); } else { std::vector ins = {qk_out, src_mask}; std::vector outs = {qk_out}; - phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, 
phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, phi::funcs::AddFunctor()); } phi::SoftmaxForwardCUDAKernelDriver(dev_ctx_, *qk_out, -1, softmax_out); } diff --git a/paddle/fluid/operators/fused_token_prune_op.cu b/paddle/fluid/operators/fused_token_prune_op.cu index 434c072e5aa..8f0a53611f3 100644 --- a/paddle/fluid/operators/fused_token_prune_op.cu +++ b/paddle/fluid/operators/fused_token_prune_op.cu @@ -141,8 +141,7 @@ class FusedTokenPruneOpCUDAKernel : public framework::OpKernel { ins.emplace_back(attn); ins.emplace_back(mask); outs.emplace_back(&attn_tmp); - LaunchElementwiseCudaKernel( - dev_ctx, ins, &outs, -1, AttnMaskFunctor()); + LaunchElementwiseCudaKernel(dev_ctx, ins, &outs, AttnMaskFunctor()); // 2. Reduce sum const std::vector reduce_dims{1, 2}; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 3349400a2f9..5cea4fa9e05 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -834,12 +834,11 @@ class ReduceCudaGradKernel : public framework::OpKernel { } using MPType = typename kps::details::MPTypeTrait::Type; - phi::ReduceGrad>( - dev_ctx, - pt_d_out.get(), - pt_d_x.get(), - pt_out_dtype, - TransformOp(reduce_num)); + phi::ReduceGrad>(dev_ctx, + pt_d_out.get(), + pt_d_x.get(), + pt_out_dtype, + TransformOp(reduce_num)); } }; diff --git a/paddle/phi/kernels/cpu/bitwise_kernel.cc b/paddle/phi/kernels/cpu/bitwise_kernel.cc index 80424ef624f..a6297efd9cd 100644 --- a/paddle/phi/kernels/cpu/bitwise_kernel.cc +++ b/paddle/phi/kernels/cpu/bitwise_kernel.cc @@ -24,15 +24,15 @@ limitations under the License. */ namespace phi { -#define DEFINE_BITWISE_KERNEL(op_type) \ - template \ - void Bitwise##op_type##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - DenseTensor* out) { \ - funcs::Bitwise##op_type##Functor func; \ - funcs::ElementwiseCompute, T, T>( \ - dev_ctx, x, y, -1, func, out); \ +#define DEFINE_BITWISE_KERNEL(op_type) \ + template \ + void Bitwise##op_type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + funcs::Bitwise##op_type##Functor func; \ + funcs::ElementwiseCompute, T>( \ + dev_ctx, x, y, func, out); \ } DEFINE_BITWISE_KERNEL(And) diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index cf8eb47fb42..0fd1332d76c 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -33,10 +33,10 @@ inline void CompareKernelImpl(const Context& ctx, ctx.template Alloc(out); if (x.dims().size() >= y.dims().size()) { funcs::ElementwiseCompute( - ctx, x, y, axis, Functor(), out); + ctx, x, y, Functor(), out, axis); } else { funcs::ElementwiseCompute( - ctx, x, y, axis, InverseFunctor(), out); + ctx, x, y, InverseFunctor(), out, axis); } } @@ -59,7 +59,7 @@ inline void CompareAllKernelImpl(const Context& ctx, tmp_data[0] = Functor()(x.data()[0], y.data()[0]); } else { funcs::ElementwiseCompute( - ctx, x, y, 0, Functor(), &tmp); + ctx, x, y, Functor(), &tmp, 0); } auto tmp_flat = EigenVector::Flatten(tmp); auto out_es = EigenScalar::From(*out); diff --git a/paddle/phi/kernels/cpu/dirichlet_kernel.cc b/paddle/phi/kernels/cpu/dirichlet_kernel.cc index c124920dfa0..855e6bdfe1e 100644 --- a/paddle/phi/kernels/cpu/dirichlet_kernel.cc +++ b/paddle/phi/kernels/cpu/dirichlet_kernel.cc @@ -91,8 +91,8 @@ struct DirichletSampler { true, false); - 
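Note on the BroadcastKernel calls above: the explicit ElementwiseType::kUnary/kBinary/kTernary template argument can be dropped because BroadcastKernel recovers the arity from the functor itself through phi::funcs::FunctionTraits (visible in broadcast_function.h further down). A minimal standalone sketch of that deduction; FunctorArity is an illustrative name, not Paddle's implementation:

#include <cstdio>

// Recover the argument count from a functor's const operator().
template <typename T>
struct FunctorArity : FunctorArity<decltype(&T::operator())> {};

template <typename C, typename R, typename... Args>
struct FunctorArity<R (C::*)(Args...) const> {
  static constexpr int value = sizeof...(Args);
};

template <typename T>
struct AddFunctor {
  T operator()(T a, T b) const { return a + b; }
};

int main() {
  // Arity comes from the functor -- no kUnary/kBinary/kTernary needed.
  std::printf("AddFunctor arity = %d\n",
              FunctorArity<AddFunctor<float>>::value);  // prints 2
}

Deriving arity this way keeps call sites from repeating information the functor already carries, which is what lets the ElementwiseType enum be deleted outright in elementwise_base.h.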
funcs::ElementwiseCompute, T, T>( - dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor(), out); + funcs::ElementwiseCompute, T>( + dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor(), out); } }; diff --git a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc index 40d0e863ea5..8dbbfc13b81 100644 --- a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc @@ -38,10 +38,10 @@ void DivideRawKernel(const Context& dev_ctx, auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::DivideFunctor(), out); + dev_ctx, x, y, funcs::DivideFunctor(), out, axis); } else { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::InverseDivideFunctor(), out); + dev_ctx, x, y, funcs::InverseDivideFunctor(), out, axis); } } } diff --git a/paddle/phi/kernels/cpu/elementwise_kernel.cc b/paddle/phi/kernels/cpu/elementwise_kernel.cc index 9b564679b35..5af36cfbf20 100644 --- a/paddle/phi/kernels/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_kernel.cc @@ -75,7 +75,7 @@ void HeavisideKernel(const Context& dev_ctx, // allocate memory for out dev_ctx.template Alloc(out); funcs::ElementwiseCompute, T>( - dev_ctx, x, y, -1, funcs::ElementwiseHeavisideFunctor(), out); + dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor(), out); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc index c42e423ba2d..630e786b571 100644 --- a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc @@ -68,20 +68,15 @@ void LayerNormGradKernel(const Context& dev_ctx, temp_norm.Resize(matrix_shape); dev_ctx.template Alloc(&temp_norm); // get x_norm - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - x_tmp, - mean, - /*axis*/ 0, - funcs::SubtractFunctor(), - &temp_norm); - phi::funcs::ElementwiseCompute, T, T>( + phi::funcs::ElementwiseCompute, T>( + dev_ctx, x_tmp, mean, funcs::SubtractFunctor(), &temp_norm, 0); + phi::funcs::ElementwiseCompute, T>( dev_ctx, temp_norm, variance, - /*axis*/ 0, funcs::DivAndSqrtFunctor(static_cast(epsilon)), - &temp_norm); + &temp_norm, + 0); } if (d_bias) { @@ -90,8 +85,8 @@ void LayerNormGradKernel(const Context& dev_ctx, } if (d_scale) { dev_ctx.template Alloc(d_scale); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, temp_norm, d_y, 0, funcs::MultiplyFunctor(), &temp); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, temp_norm, d_y, funcs::MultiplyFunctor(), &temp, 0); colwise_sum(dev_ctx, temp, d_scale); } @@ -107,70 +102,45 @@ void LayerNormGradKernel(const Context& dev_ctx, if (d_scale) { // dy_dx - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, d_y, *scale, /*axis*/ 1, funcs::MultiplyFunctor(), &temp); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, d_y, *scale, funcs::MultiplyFunctor(), &temp, 1); phi::Copy(dev_ctx, temp, dev_ctx.GetPlace(), false, d_x); // dy_dmean_dx row_mean(dev_ctx, temp, &temp_vec); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - *d_x, - temp_vec, - /*axis*/ 0, - funcs::SubtractFunctor(), - d_x); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, *d_x, temp_vec, funcs::SubtractFunctor(), d_x, 0); // dy_var_dx - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - temp, - temp_norm, - /*axis*/ 0, - funcs::MultiplyFunctor(), - &temp); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, temp, temp_norm, funcs::MultiplyFunctor(), &temp, 
0); } else { // dy_dx phi::Copy(dev_ctx, d_y, dev_ctx.GetPlace(), false, d_x); // dy_dmean_dx row_mean(dev_ctx, d_y, &temp_vec); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - *d_x, - temp_vec, - /*axis*/ 0, - funcs::SubtractFunctor(), - d_x); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, *d_x, temp_vec, funcs::SubtractFunctor(), d_x, 0); // dy_var_dx - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - d_y, - temp_norm, - /*axis*/ 0, - funcs::MultiplyFunctor(), - &temp); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, d_y, temp_norm, funcs::MultiplyFunctor(), &temp, 0); } // dy_var_dx row_mean(dev_ctx, temp, &temp_vec); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - temp_norm, - temp_vec, - /*axis*/ 0, - funcs::MultiplyFunctor(), - &temp); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, *d_x, temp, /*axis*/ 0, funcs::SubtractFunctor(), d_x); - - phi::funcs::ElementwiseCompute, T, T>( + phi::funcs::ElementwiseCompute, T>( + dev_ctx, temp_norm, temp_vec, funcs::MultiplyFunctor(), &temp, 0); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, *d_x, temp, funcs::SubtractFunctor(), d_x, 0); + + phi::funcs::ElementwiseCompute, T>( dev_ctx, *d_x, variance, - /*axis*/ 0, funcs::DivAndSqrtFunctor(static_cast(epsilon)), - d_x); + d_x, + 0); d_x->Resize(dx_dim); } } diff --git a/paddle/phi/kernels/cpu/layer_norm_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_kernel.cc index 1c82866f0bb..2a93d03b4ab 100644 --- a/paddle/phi/kernels/cpu/layer_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/layer_norm_kernel.cc @@ -67,30 +67,30 @@ void LayerNormKernel(const Context& dev_ctx, // get variance - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, x_tmp, *mean, 0, funcs::SubAndSquareFunctor(), &out); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, x_tmp, *mean, funcs::SubAndSquareFunctor(), &out, 0); row_mean(dev_ctx, out, var); // get x_norm - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, x_tmp, *mean, 0, funcs::SubtractFunctor(), &out); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, x_tmp, *mean, funcs::SubtractFunctor(), &out, 0); - phi::funcs::ElementwiseCompute, T, T>( + phi::funcs::ElementwiseCompute, T>( dev_ctx, out, *var, - 0, funcs::DivAndSqrtFunctor(static_cast(epsilon)), - &out); + &out, + 0); if (scale) { - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, out, *scale, 1, funcs::MultiplyFunctor(), &out); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, out, *scale, funcs::MultiplyFunctor(), &out, 1); } if (bias) { - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, out, *bias, 1, funcs::AddFunctor(), &out); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, out, *bias, funcs::AddFunctor(), &out, 1); } #else PADDLE_ENFORCE_EQ(mean->numel(), diff --git a/paddle/phi/kernels/cpu/logical_kernel.cc b/paddle/phi/kernels/cpu/logical_kernel.cc index 3a669883f78..38c927e976f 100644 --- a/paddle/phi/kernels/cpu/logical_kernel.cc +++ b/paddle/phi/kernels/cpu/logical_kernel.cc @@ -32,7 +32,7 @@ namespace phi { DenseTensor* out) { \ funcs::Logical##type##Functor binary_func; \ funcs::ElementwiseCompute, T, bool>( \ - dev_ctx, x, y, -1, binary_func, out); \ + dev_ctx, x, y, binary_func, out); \ } DEFINE_LOGICAL_BINARY_KERNEL(And) diff --git a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc index fbb16138567..e1188fda486 100644 --- a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc @@ -132,11 +132,10 @@ void MatrixRankTolKernel(const Context& dev_ctx, DenseTensor tol_tensor; 
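The ElementwiseCompute migrations in these hunks all follow one shape: the functor moves ahead of the output, axis becomes a trailing parameter defaulting to -1, and the duplicated output-type template argument is dropped. A standalone sketch of the signature pattern only; toy semantics, same-shaped inputs, not Paddle's implementation:

#include <cstddef>
#include <cstdio>
#include <vector>

// OutT defaults to the input type; axis trails with a default of -1.
template <typename Functor, typename InT, typename OutT = InT>
void ElementwiseComputeSketch(const std::vector<InT>& x,
                              const std::vector<InT>& y,
                              Functor func,
                              std::vector<OutT>* z,
                              int axis = -1) {
  (void)axis;  // -1 means "derive from the rank difference"; this toy
               // version only handles same-shaped inputs
  z->resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) (*z)[i] = func(x[i], y[i]);
}

int main() {
  std::vector<float> x{1, 2, 3}, y{4, 5, 6}, z;
  auto add = [](float a, float b) { return a + b; };
  ElementwiseComputeSketch(x, y, add, &z);  // axis defaulted away
  std::printf("%.0f %.0f %.0f\n", z[0], z[1], z[2]);  // 5 7 9
}

Because most call sites use the default axis, the reordering turns the common case into one argument fewer, which is what the bulk of this patch mechanically applies.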
tol_tensor.Resize(dim_out); dev_ctx.template Alloc(&tol_tensor); - funcs::ElementwiseCompute, T, T>( + funcs::ElementwiseCompute, T>( dev_ctx, atol_tensor, rtol_tensor, - -1, GreaterElementFunctor(), &tol_tensor); @@ -151,17 +150,17 @@ void MatrixRankTolKernel(const Context& dev_ctx, dev_ctx, eigenvalue_tensor, tol_tensor, - axis, funcs::GreaterThanFunctor(), - &compare_result); + &compare_result, + axis); } else { funcs::ElementwiseCompute, T, int>( dev_ctx, eigenvalue_tensor, tol_tensor, - axis, funcs::LessThanFunctor(), - &compare_result); + &compare_result, + axis); } phi::SumKernel(dev_ctx, diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index b4125ab5550..d0f3fba392a 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -189,45 +189,29 @@ struct BroadcastDataLoader { } }; -// Common broadcast data loader. -template -struct BroadcastDataLoader { - template - static __device__ __forceinline__ void Apply(const Array1 &ins, - ArgsT *args, - const Array2 &configs, - const Array3 &use_broadcast, - const int block_offset, - const int num, - const uint32_t numel) { +template +struct BroadcastDataInit { + template + static __device__ __forceinline__ void Apply(ArgsT *args) { using Type = std::tuple_element_t; - uint32_t index_bc[VecSize]; #pragma unroll for (int k = 0; k < VecSize; ++k) { - index_bc[k] = 0; std::get(args[k]) = static_cast(1); } + } +}; - uint32_t thread_offset = block_offset + threadIdx.x * VecSize; -#pragma unroll - for (int k = 0; k < VecSize; ++k) { - uint32_t idx = thread_offset + k; - if (IsBoundary && idx == numel) { - break; - } -#pragma unroll - for (int i = 0; i < phi::DDim::kMaxRank; ++i) { - if (i == configs[0].rank) break; - auto fast_divmoder = configs[0].divmoders[i].Divmod(idx); - idx = fast_divmoder.val[0]; - index_bc[k] += fast_divmoder.val[1] * configs[Index].strides[i]; - } - } - +template +struct BroadcastDataSetter { + template + static __device__ __forceinline__ void Apply(const Array &ins, + ArgsT *args, + uint32_t index_bc[][VecSize]) { + using Type = std::tuple_element_t; #pragma unroll for (int k = 0; k < VecSize; ++k) { std::get(args[k]) = - reinterpret_cast(ins[Index])[index_bc[k]]; + reinterpret_cast(ins[Index])[index_bc[Index][k]]; } } }; @@ -285,8 +269,30 @@ __device__ void VectorizedBroadcastKernelImpl( __simd__ ArgsT args[VecSize]; __simd__ ConditionalT result[VecSize]; - BcUnroller::step( - ins, args, configs, use_broadcast, block_offset, num, numel); + if (LoadType == kBroadcast) { + uint32_t index_bc[Arity][VecSize] = {0}; + Unroller::step(args); + uint32_t thread_offset = block_offset + threadIdx.x * VecSize; +#pragma unroll + for (int k = 0; k < VecSize; ++k) { + uint32_t idx = thread_offset + k; + if (IsBoundary && idx == numel) break; +#pragma unroll + for (int i = 0; i < phi::DDim::kMaxRank; ++i) { + if (i == configs[0].rank) break; + auto fast_divmoder = configs[0].divmoders[i].Divmod(idx); + idx = fast_divmoder.val[0]; +#pragma unroll + for (int j = 0; j < Arity; ++j) { + index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i]; + } + } + } + Unroller::step(ins, args, index_bc); + } else { + BcUnroller::step( + ins, args, configs, use_broadcast, block_offset, num, numel); + } SameDimsElementwisePrimitiveCaller, VecSize, @@ -783,11 +789,7 @@ struct LaunchBroadcastKernelWithInt64IndexHelper +template void BroadcastKernelForDifferentVecSize( const KPDevice &ctx, const std::vector &ins, @@ -922,16 +924,12 @@ void 
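The new kBroadcast path above decomposes each output index once and lets every input accumulate its offset from the same divmod results, which is the "remove redundant fast divmod computation" item in the commit message: previously each input's loader repeated the division chain. A standalone sketch with plain / and % standing in for phi's FastDivMod; shapes illustrative:

#include <cstdio>

int main() {
  // out[2][3] = a[2][3] op b[1][3], b broadcast along dim 0.
  const int rank = 2;
  const int out_dims[2] = {3, 2};   // innermost dimension first
  const int a_strides[2] = {1, 3};  // a is 2x3, row-major
  const int b_strides[2] = {1, 0};  // b is 1x3: stride 0 on broadcast dim
  for (int idx = 0; idx < 6; ++idx) {
    int rem = idx, off_a = 0, off_b = 0;
    for (int i = 0; i < rank; ++i) {
      const int coord = rem % out_dims[i];  // one divmod pass per element...
      rem /= out_dims[i];
      off_a += coord * a_strides[i];  // ...feeds every input's offset
      off_b += coord * b_strides[i];
    }
    std::printf("out %d <- a %d, b %d\n", idx, off_a, off_b);
  }
}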
BroadcastKernelForDifferentVecSize( } } -template +template void BroadcastKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, - int axis, - Functor func) { + Functor func, + int axis = -1) { // When there are multiple inputs, the outputs's rank should be equal the // maximum rank of all inputs. using Traits = phi::funcs::FunctionTraits; @@ -968,7 +966,7 @@ void BroadcastKernel(const KPDevice &ctx, max_rank = std::max(max_rank, (*outs)[0]->dims().size()); } axis = axis == -1 ? max_rank - min_rank : axis; - BroadcastKernelForDifferentVecSize( + BroadcastKernelForDifferentVecSize( ctx, ins, outs, axis, func); } @@ -976,15 +974,14 @@ template void ElementwiseCompute(const GPUContext &dev_ctx, const DenseTensor &x, const DenseTensor &y, - int axis, Functor func, - DenseTensor *z) { + DenseTensor *z, + int axis = -1) { std::vector ins = {&x, &y}; std::vector outs = {z}; dev_ctx.template Alloc(z); - BroadcastKernel( - dev_ctx, ins, &outs, axis, func); + BroadcastKernel(dev_ctx, ins, &outs, func, axis); } template (z); - funcs::ElementwiseCompute(dev_ctx, x, y, axis, Functor(), z); + funcs::ElementwiseCompute(dev_ctx, x, y, Functor(), z, axis); } #else @@ -1017,10 +1014,10 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx, auto y_dims = y.dims(); dev_ctx.template Alloc(z); if (x_dims.size() >= y_dims.size()) { - funcs::ElementwiseCompute(dev_ctx, x, y, axis, Functor(), z); + funcs::ElementwiseCompute(dev_ctx, x, y, Functor(), z, axis); } else { funcs::ElementwiseCompute( - dev_ctx, x, y, axis, InverseFunctor(), z); + dev_ctx, x, y, InverseFunctor(), z, axis); } } #endif diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h index 2ae9d3c02f6..48a7008463c 100644 --- a/paddle/phi/kernels/funcs/dropout_impl.cu.h +++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h @@ -191,25 +191,19 @@ __global__ void VectorizedRandomGenerator(const size_t n, } template -__global__ void DropOutNdForwardKernel( - const size_t n, - uint64_t seed, - const float dropout_prob, - const T* src, - uint8_t* mask, - uint64_t increment, - size_t main_offset, - DstFunctor dst_functor, - MaskFunctor mask_functor, - T* y, - int64_t N, - kps::details::BroadcastConfig broadcast_config, - const uint64_t* seed_ptr) { +__global__ void VectorizedGeneratorMask(const size_t n, + uint64_t seed, + const float dropout_prob, + const T* src, + uint8_t* mask, + uint64_t increment, + size_t main_offset, + MaskFunctor mask_functor, + + const uint64_t* seed_ptr) { // Vectorized Generate Mask // kCount is 4 for curand_uniform4 is used - if (seed_ptr) { - seed = seed_ptr[0]; - } + if (seed_ptr) seed = seed_ptr[0]; constexpr int kCount = phi::funcs::uniform_distribution::kReturnsCount; size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); @@ -259,22 +253,6 @@ __global__ void DropOutNdForwardKernel( kps::WriteData( mask + fix, &mask_result[0], remainder); } - // Broadcast mask data and do elementwise operaiton with DstFunctor - CUDA_KERNEL_LOOP(i, N) { - uint32_t offset = 0u; - uint32_t idx = i; - // Use (j < phi::DDim::kMaxRank) conditiion rather than - // (j < broadcast_config.rank) for (#pragma unroll) -#pragma unroll - for (int j = 0; j < phi::DDim::kMaxRank; ++j) { - if (j == broadcast_config.rank) break; - auto fast_divmoder = broadcast_config.divmoders[j].Divmod(idx); - idx = fast_divmoder.val[0]; - offset += broadcast_config.strides[j] * fast_divmoder.val[1]; - } - __syncthreads(); - y[i] = dst_functor(src[i], mask[offset]); - } } template @@ -348,18 +326,6 @@ void 
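The dropout-nd forward path is being split here: VectorizedGeneratorMask only produces the mask on the mask's own (smaller) shape, and the broadcasted application of DstFunctor moves into an ordinary BroadcastKernel call in the driver below, instead of a hand-rolled CUDA_KERNEL_LOOP. A CPU-side sketch of the two steps; std::mt19937 stands in for curand, everything illustrative:

#include <cstdio>
#include <random>
#include <vector>

int main() {
  const int rows = 4, cols = 3;
  const float dropout_prob = 0.5f;
  const float scale = 1.0f / (1.0f - dropout_prob);  // upscale_in_train

  std::vector<float> x(rows * cols, 1.0f);
  std::vector<unsigned char> mask(rows);  // one keep/drop decision per row

  // Step 1: generate the mask on the mask shape only.
  std::mt19937 rng(42);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  for (int r = 0; r < rows; ++r) mask[r] = uniform(rng) >= dropout_prob;

  // Step 2: broadcast-apply the dst functor: y = x * mask * scale.
  std::vector<float> y(rows * cols);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      y[r * cols + c] = x[r * cols + c] * static_cast<float>(mask[r]) * scale;

  for (int r = 0; r < rows; ++r)
    std::printf("row %d: mask=%d y0=%.1f\n", r, mask[r], y[r * cols]);
}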
DropoutFwGPUKernelDriver( size / (block_size * kVecSize) * (block_size * kVecSize); if (is_dropout_nd) { - auto dst_functor = - DstFunctor(1.0f - dropout_prob, upscale_in_train, x_numel); - - std::vector out_dims = - std::move(phi::vectorize(x.dims())); - std::vector in_dims = - std::move(phi::vectorize(mask->dims())); - std::reverse(out_dims.begin(), out_dims.end()); - std::reverse(in_dims.begin(), in_dims.end()); - kps::details::BroadcastConfig broadcast_config( - out_dims, in_dims, x.dims().size()); - auto mask_functor = MaskFunctor(1.0f - dropout_prob); bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx, seed, @@ -372,20 +338,22 @@ void DropoutFwGPUKernelDriver( const uint64_t* seed_ptr = copy_in_kernel ? seed->data() : nullptr; - DropOutNdForwardKernel + VectorizedGeneratorMask <<>>(size, seed_data, dropout_prob, x_data, mask_data, + increment, main_offset, - dst_functor, mask_functor, - y_data, - y->numel(), - broadcast_config, seed_ptr); + auto dst_functor = + DstFunctor(1.0f - dropout_prob, upscale_in_train, x_numel); + std::vector ins = {&x, mask}; + std::vector outs = {y}; + phi::funcs::BroadcastKernel(dev_ctx, ins, &outs, dst_functor); } else { bool copy_in_kernel = GetSeedDataAndIncrement( dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment); @@ -469,30 +437,13 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx, MT factor = upscale_in_train ? static_cast(1.0f / (1.0f - dropout_prob)) : static_cast(1.0f); + + std::vector ins = {&grad_y, &mask}; + std::vector outs = {grad_x}; if (is_dropout_nd) { - phi::DenseTensor broadcasted_mask; - - broadcasted_mask.Resize(grad_y.dims()); - dev_ctx.template Alloc(&broadcasted_mask); - - std::vector broadcast_ins = {&mask}; - std::vector broadcast_outs = {&broadcasted_mask}; - phi::funcs::BroadcastKernel(dev_ctx, - broadcast_ins, - &broadcast_outs, - -1, - kps::IdentityFunctor()); - - std::vector ins = {&grad_y, &broadcasted_mask}; - std::vector outs = {grad_x}; - phi::funcs::ElementwiseKernel( + phi::funcs::BroadcastKernel( dev_ctx, ins, &outs, CudaDropoutGradFunctor(factor)); - } else { - std::vector ins = {&grad_y, &mask}; - std::vector outs = {grad_x}; phi::funcs::ElementwiseKernel( dev_ctx, ins, &outs, CudaDropoutGradFunctor(factor)); } diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 0f6c225d89b..274ac1cc32c 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -35,7 +35,6 @@ namespace kps = phi::kps; namespace phi { -enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 }; /* Packing scalar type T(float, int etc.) 
into Array type for supporting multiple-output feature in elementwise system.*/ template @@ -369,9 +368,9 @@ template void ElementwiseCompute(const CPUContext &dev_ctx, const DenseTensor &x, const DenseTensor &y, - int axis, Functor func, - DenseTensor *z) { + DenseTensor *z, + int axis = -1) { dev_ctx.Alloc(z); auto x_dims = x.dims(); auto y_dims = y.dims(); diff --git a/paddle/phi/kernels/fusion/gpu/attn_gemm.h b/paddle/phi/kernels/fusion/gpu/attn_gemm.h index 01544436e4d..a96601dddac 100644 --- a/paddle/phi/kernels/fusion/gpu/attn_gemm.h +++ b/paddle/phi/kernels/fusion/gpu/attn_gemm.h @@ -112,8 +112,8 @@ class AttnMatMul { // bias_out = output + bias std::vector ins = {output, bias}; std::vector outs = {bias_out}; - phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, phi::funcs::AddFunctor()); } } diff --git a/paddle/phi/kernels/fusion/gpu/fmha_ref.h b/paddle/phi/kernels/fusion/gpu/fmha_ref.h index 207be7b5d7b..a41582c7076 100644 --- a/paddle/phi/kernels/fusion/gpu/fmha_ref.h +++ b/paddle/phi/kernels/fusion/gpu/fmha_ref.h @@ -258,12 +258,11 @@ class FMHARef { ins.emplace_back(src_mask_tensor); outs.emplace_back(src_mask_out_tensor); int elewise_add_axis = -1; - phi::funcs::BroadcastKernel( - dev_ctx_, - ins, - &outs, - elewise_add_axis, - phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel(dev_ctx_, + ins, + &outs, + phi::funcs::AddFunctor(), + elewise_add_axis); phi::SoftmaxForwardCUDAKernelDriver( dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); @@ -435,12 +434,11 @@ class FMHARef { ins.emplace_back(src_mask_tensor); outs.emplace_back(src_mask_out_tensor); int elewise_add_axis = -1; - phi::funcs::BroadcastKernel( - dev_ctx_, - ins, - &outs, - elewise_add_axis, - phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel(dev_ctx_, + ins, + &outs, + phi::funcs::AddFunctor(), + elewise_add_axis); phi::SoftmaxForwardCUDAKernelDriver( dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); diff --git a/paddle/phi/kernels/gpu/dirichlet_kernel.cu b/paddle/phi/kernels/gpu/dirichlet_kernel.cu index eacbab80057..09d6a402e70 100644 --- a/paddle/phi/kernels/gpu/dirichlet_kernel.cu +++ b/paddle/phi/kernels/gpu/dirichlet_kernel.cu @@ -106,8 +106,8 @@ struct DirichletSampler { {new_shape.size() - 1}, true, false); - funcs::ElementwiseCompute, T, T>( - dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor(), out); + funcs::ElementwiseCompute, T>( + dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor(), out); } }; } // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_divide_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_divide_grad_kernel.cu index 57bf6da4060..58cdc63b4a3 100644 --- a/paddle/phi/kernels/gpu/elementwise_divide_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_divide_grad_kernel.cu @@ -37,22 +37,21 @@ void DivideGradKernel(const Context& dev_ctx, const auto place = dev_ctx.GetPlace(); if (dx != nullptr && dy != nullptr) { std::vector ins = {&dout, &out, &y}; - GetGradXAndYOut( - dev_ctx, - place, - axis, - ins, - dout, - dx, - dy, - funcs::DivGradXYFunctor()); + GetGradXAndYOut(dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::DivGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { std::vector ins = {&dout, &y}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor()); } else if (dy != nullptr && dx == nullptr) { std::vector ins = {&dout, &out, &y}; - GetGradXOrYOut( + 
GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor()); } } diff --git a/paddle/phi/kernels/gpu/elementwise_grad.h b/paddle/phi/kernels/gpu/elementwise_grad.h index 8440ed2b122..6152425f272 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad.h +++ b/paddle/phi/kernels/gpu/elementwise_grad.h @@ -35,7 +35,7 @@ void ReduceWrapper(const GPUContext &dev_ctx, dev_ctx, *src, dst, kps::IdentityFunctor(), reduce_dims); } -template +template void GetGradXAndYOut(const GPUContext &dev_ctx, const Place &place, int axis, @@ -67,8 +67,7 @@ void GetGradXAndYOut(const GPUContext &dev_ctx, outs = {&tmp_dx, &tmp_dy}; } - funcs::BroadcastKernel( - dev_ctx, ins, &outs, axis, func); + funcs::BroadcastKernel(dev_ctx, ins, &outs, func, axis); if (dx->dims() != dout.dims() && dy->dims() == dout.dims()) { ReduceWrapper(dev_ctx, axis, &tmp_dx, dx); @@ -80,7 +79,7 @@ void GetGradXAndYOut(const GPUContext &dev_ctx, } } -template +template void GetGradXOrYOut(const GPUContext &dev_ctx, const Place &place, int axis, @@ -100,7 +99,7 @@ void GetGradXOrYOut(const GPUContext &dev_ctx, outs = {dxy}; } - funcs::BroadcastKernel(dev_ctx, ins, &outs, axis, func); + funcs::BroadcastKernel(dev_ctx, ins, &outs, func, axis); if (dxy->dims() != dout.dims()) { ReduceWrapper(dev_ctx, axis, &tmp_dxy, dxy); } @@ -342,22 +341,21 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx, const auto place = dev_ctx.GetPlace(); if (dx != nullptr && dy != nullptr) { std::vector ins = {&dout, &out, &y}; - GetGradXAndYOut( - dev_ctx, - place, - axis, - ins, - dout, - dx, - dy, - funcs::DivGradXYFunctor()); + GetGradXAndYOut(dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::DivGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { std::vector ins = {&dout, &y}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor()); } else if (dy != nullptr && dx == nullptr) { std::vector ins = {&dout, &out, &y}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor()); } } @@ -380,22 +378,21 @@ void ElementwiseMulGrad(const GPUContext &dev_ctx, if (dx != nullptr && dy != nullptr) { std::vector ins = {&dout, &y, &x}; - GetGradXAndYOut( - dev_ctx, - place, - axis, - ins, - dout, - dx, - dy, - funcs::MultiplyGradXYFunctor()); + GetGradXAndYOut(dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::MultiplyGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { std::vector ins = {&dout, &y}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dx, funcs::MultiplyGradFunctor()); } else if (dx == nullptr && dy != nullptr) { std::vector ins = {&dout, &x}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dy, funcs::MultiplyGradFunctor()); } } diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index b69434c82da..efd15adc383 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -37,22 +37,21 @@ void MaximumGradKernel(const Context& dev_ctx, int axis = -1; if (dx != nullptr && dy != nullptr) { std::vector ins = {&x, &y, &dout}; - GetGradXAndYOut( - dev_ctx, - place, - axis, - ins, - dout, - dx, - dy, - funcs::MaxGradXYFunctor()); + GetGradXAndYOut(dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::MaxGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { std::vector ins = {&x, &y, &dout}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, 
dout, dx, funcs::MaxGradXFunctor()); } else if (dy != nullptr && dx == nullptr) { std::vector ins = {&x, &y, &dout}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dy, funcs::MaxGradYFunctor()); } } @@ -68,22 +67,21 @@ void MinimumGradKernel(const Context& dev_ctx, int axis = -1; if (dx != nullptr && dy != nullptr) { std::vector ins = {&x, &y, &dout}; - GetGradXAndYOut( - dev_ctx, - place, - axis, - ins, - dout, - dx, - dy, - funcs::MinGradXYFunctor()); + GetGradXAndYOut(dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::MinGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { std::vector ins = {&x, &y, &dout}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dx, funcs::MinGradXFunctor()); } else if (dy != nullptr && dx == nullptr) { std::vector ins = {&x, &y, &dout}; - GetGradXOrYOut( + GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dy, funcs::MinGradYFunctor()); } } diff --git a/paddle/phi/kernels/gpu/expand_as_kernel.cu b/paddle/phi/kernels/gpu/expand_as_kernel.cu index f87a2417ced..de47024f295 100644 --- a/paddle/phi/kernels/gpu/expand_as_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_as_kernel.cu @@ -74,8 +74,7 @@ void ExpandAsKernel(const Context& ctx, ctx.template Alloc(out); std::vector ins = {&x}; std::vector outs = {out}; - phi::funcs::BroadcastKernel( - ctx, ins, &outs, -1, kps::IdentityFunctor()); + phi::funcs::BroadcastKernel(ctx, ins, &outs, kps::IdentityFunctor()); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/expand_kernel.cu b/paddle/phi/kernels/gpu/expand_kernel.cu index b2f973b0a88..456aa9b3c5a 100644 --- a/paddle/phi/kernels/gpu/expand_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_kernel.cu @@ -73,8 +73,7 @@ void ExpandKernel(const Context& ctx, ctx.template Alloc(out); std::vector ins = {&x}; std::vector outs = {out}; - phi::funcs::BroadcastKernel( - ctx, ins, &outs, -1, kps::IdentityFunctor()); + phi::funcs::BroadcastKernel(ctx, ins, &outs, kps::IdentityFunctor()); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu index 620341f338e..e4ee1f34213 100644 --- a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu +++ b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu @@ -407,11 +407,10 @@ void MatrixRankTolKernel(const Context& dev_ctx, tol_tensor.Resize(dim_out); dev_ctx.template Alloc(&tol_tensor); - funcs::ElementwiseCompute, T, T>( + funcs::ElementwiseCompute, T>( dev_ctx, atol_tensor, rtol_tensor, - -1, GreaterElementFunctor(), &tol_tensor); @@ -421,12 +420,10 @@ void MatrixRankTolKernel(const Context& dev_ctx, compare_result.Resize(detail::NewAxisDim(dim_out, k)); dev_ctx.template Alloc(&compare_result); - int axis = -1; funcs::ElementwiseCompute, T, int64_t>( dev_ctx, eigenvalue_tensor, tol_tensor, - axis, funcs::GreaterThanFunctor(), &compare_result); diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h index 5ba779d027a..04befb29b2d 100644 --- a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h +++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h @@ -78,8 +78,8 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx, // 1. equal_out = Equal(x, y) std::vector equal_inputs = {&new_y, new_in_tensor}; std::vector equal_outputs = {&equal_out_tensor}; - funcs::BroadcastKernel( - dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor()); + funcs::BroadcastKernel( + dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor(), 0); // 2. 
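Step 1 above builds the equality mask against the reduced result; the steps below reduce it to a per-slice tie count and scale the routed gradient, so elements that tie for the extremum share it. A standalone sketch of the amax case; shapes illustrative:

#include <cstdio>

int main() {
  const int n = 2, d = 3;
  const float x[6] = {1, 5, 2, 7, 3, 7};
  const float out[2] = {5, 7};    // amax over the last axis
  const float dout[2] = {6, 6};   // incoming gradient, shaped like out
  float dx[6];
  for (int i = 0; i < n; ++i) {
    float tie_count = 0;  // equal_count = reduceSum(equal_out)
    for (int j = 0; j < d; ++j) tie_count += (x[i * d + j] == out[i]);
    for (int j = 0; j < d; ++j) {
      const float equal = (x[i * d + j] == out[i]) ? 1.0f : 0.0f;
      dx[i * d + j] = dout[i] * equal / tie_count;  // ties split the grad
    }
  }
  for (int i = 0; i < 6; ++i) std::printf("%.0f ", dx[i]);  // 0 6 0 3 0 3
  std::printf("\n");
}

The plain max/min grad kernels below apply the same equal-then-multiply pair but skip the division, so every tied entry receives the full gradient.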
equal_count = reduceSum(equal_out) using MPType = typename kps::details::MPTypeTrait::Type; phi::funcs:: @@ -95,15 +95,15 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx, std::vector mul_inputs = {&new_dout, &equal_out_tensor}; std::vector mul_outputs = {&equal_out_tensor}; - funcs::BroadcastKernel( - dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor()); + funcs::BroadcastKernel( + dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor(), 0); // 4. dx = Div(dx, equal_out) std::vector grad_inputs = {&equal_out_tensor, equal_count}; std::vector grad_outputs = {new_dx_tensor}; - funcs::BroadcastKernel( - dev_ctx, grad_inputs, &grad_outputs, 0, funcs::DivideFunctor()); + funcs::BroadcastKernel( + dev_ctx, grad_inputs, &grad_outputs, funcs::DivideFunctor(), 0); delete equal_out; delete equal_count; } diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h index 01f91924645..7e01c1ae843 100644 --- a/paddle/phi/kernels/gpu/reduce_grad.h +++ b/paddle/phi/kernels/gpu/reduce_grad.h @@ -28,7 +28,7 @@ namespace phi { -template +template void ReduceGrad(const GPUContext& dev_ctx, DenseTensor* d_out, DenseTensor* d_x, @@ -36,14 +36,13 @@ void ReduceGrad(const GPUContext& dev_ctx, Functor functor) { std::vector inputs = {d_out}; std::vector outputs = {d_x}; - PD_VISIT_ALL_TYPES( - out_dtype, "BroadcastKernel", ([&] { - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, 0, functor); - })); + PD_VISIT_ALL_TYPES(out_dtype, "BroadcastKernel", ([&] { + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, functor, 0); + })); } -template +template void ReduceGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, @@ -79,8 +78,7 @@ void ReduceGradKernel(const Context& dev_ctx, auto pt_d_x = *d_x; std::vector inputs = {&pt_d_out}; std::vector outputs = {&pt_d_x}; - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, 0, functor); + funcs::BroadcastKernel(dev_ctx, inputs, &outputs, functor, 0); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_max_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_max_grad_kernel.cu index 7b4472c5223..6bee38abe1f 100644 --- a/paddle/phi/kernels/gpu/reduce_max_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_max_grad_kernel.cu @@ -62,14 +62,14 @@ void ReduceMaxGradKernel(const Context& dev_ctx, // 1. equal_out = Equal(x, y) std::vector equal_inputs = {&new_out, &x}; std::vector equal_outputs = {equal_out}; - funcs::BroadcastKernel( - dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor()); + funcs::BroadcastKernel( + dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor(), 0); // 2. 
dx = dout * 1 std::vector mul_inputs = {&new_out_grad, equal_out}; std::vector mul_outputs = {x_grad}; - funcs::BroadcastKernel( - dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor()); + funcs::BroadcastKernel( + dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor(), 0); delete equal_out; } } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu index 6519f6b0855..0eac15902af 100644 --- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu @@ -53,8 +53,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx, std::vector outputs = {x_grad}; using MPType = typename kps::details::MPTypeTrait::Type; - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, 0, kps::DivideFunctor(reduce_num)); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, kps::DivideFunctor(reduce_num), 0); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_min_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_min_grad_kernel.cu index 86cccc5e03b..7a650dc9640 100644 --- a/paddle/phi/kernels/gpu/reduce_min_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_min_grad_kernel.cu @@ -62,14 +62,14 @@ void ReduceMinGradKernel(const Context& dev_ctx, // 1. equal_out = Equal(x, y) std::vector equal_inputs = {&new_out, &x}; std::vector equal_outputs = {equal_out}; - funcs::BroadcastKernel( - dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor()); + funcs::BroadcastKernel( + dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor(), 0); // 2. dx = dout * 1 std::vector mul_inputs = {&new_out_grad, equal_out}; std::vector mul_outputs = {x_grad}; - funcs::BroadcastKernel( - dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor()); + funcs::BroadcastKernel( + dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor(), 0); delete equal_out; } } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu index 15215c05d63..9ee6d530374 100644 --- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu @@ -48,7 +48,7 @@ void ReduceSumGradKernel(const Context& dev_ctx, // call ReduceGrad dev_ctx.Alloc(x_grad, x.dtype()); using MPType = typename kps::details::MPTypeTrait::Type; - phi::ReduceGrad>( + phi::ReduceGrad>( dev_ctx, &new_out_grad, x_grad, diff --git a/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu index 4557d44f150..1d491b41ecb 100644 --- a/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu @@ -46,8 +46,7 @@ void SquaredL2NormGradKernel(const Context& dev_ctx, std::vector ins{&x, &dout}; std::vector outs{dx}; - funcs::BroadcastKernel( - dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor()); + funcs::BroadcastKernel(dev_ctx, ins, &outs, phi::DoubleMulFunctor()); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/tile_kernel.cu b/paddle/phi/kernels/gpu/tile_kernel.cu index be825eea499..7861a2bdf01 100644 --- a/paddle/phi/kernels/gpu/tile_kernel.cu +++ b/paddle/phi/kernels/gpu/tile_kernel.cu @@ -78,8 +78,8 @@ void TileKernel(const Context& dev_ctx, tmp_out.Resize(make_ddim(vec_x_dims)); dev_ctx.template Alloc(&tmp_out); std::vector outs = {&tmp_out}; - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, i, kps::IdentityFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, kps::IdentityFunctor(), i); 
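TileKernel above can drive BroadcastKernel with an identity functor because tiling is purely an indexing transformation: the repeated dimension reads the input with stride 0 while values pass through unchanged. A standalone sketch with illustrative shapes:

#include <cstdio>

int main() {
  const float in[2] = {10.0f, 20.0f};  // shape [2]
  const int reps = 3;                  // tile to shape [3][2]
  float out[3][2];
  for (int r = 0; r < reps; ++r)
    for (int c = 0; c < 2; ++c)
      out[r][c] = in[c];  // identity functor; only the index map differs
  for (int r = 0; r < reps; ++r)
    std::printf("%.0f %.0f\n", out[r][0], out[r][1]);
}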
tmp_out.Resize(out_dims); new_x = tmp_out; } @@ -89,8 +89,8 @@ void TileKernel(const Context& dev_ctx, out->Resize(make_ddim(vec_x_dims)); dev_ctx.template Alloc(out); std::vector outs = {out}; - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, i, kps::IdentityFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, kps::IdentityFunctor(), i); out->Resize(out_dims); } } diff --git a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu index 0a78df87a8c..be630f85ce0 100644 --- a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu +++ b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu @@ -91,8 +91,7 @@ struct BinaryOperation { DenseTensor* output) { std::vector ins{&lhs, &rhs}; std::vector outs{output}; - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, 0, BinaryFunctor()); + phi::funcs::BroadcastKernel(dev_ctx, ins, &outs, BinaryFunctor(), 0); } }; diff --git a/paddle/phi/kernels/impl/complex_kernel_impl.h b/paddle/phi/kernels/impl/complex_kernel_impl.h index 8bd78234119..ebbbda04a01 100644 --- a/paddle/phi/kernels/impl/complex_kernel_impl.h +++ b/paddle/phi/kernels/impl/complex_kernel_impl.h @@ -90,16 +90,16 @@ void ComplexKernel(const Context& dev_ctx, // facility functions #if defined(__NVCC__) || defined(__HIPCC__) phi::funcs::ElementwiseCompute, T, C>( - dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor(), out); + dev_ctx, x, y, RealAndImagToComplexFunctor(), out); #else auto x_dims = x.dims(); auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { phi::funcs::ElementwiseCompute, T, C>( - dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor(), out); + dev_ctx, x, y, RealAndImagToComplexFunctor(), out); } else { phi::funcs::ElementwiseCompute, T, C>( - dev_ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor(), out); + dev_ctx, x, y, ImagAndRealToComplexFunctor(), out); } #endif } diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index 1d069775f22..15f99a58fa5 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -76,15 +76,15 @@ void AddDoubleGradImpl(const Context& dev_ctx, auto ddy_dims = ddy_safe.dims(); if (ddx_dims.size() >= ddy_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, ddx_safe, ddy_safe, axis, funcs::AddFunctor(), ddout); + dev_ctx, ddx_safe, ddy_safe, funcs::AddFunctor(), ddout, axis); } else { funcs::ElementwiseCompute, T>( dev_ctx, ddx_safe, ddy_safe, - axis, funcs::InverseAddFunctor(), - ddout); + ddout, + axis); } } } @@ -107,7 +107,7 @@ void SubtractDoubleGradImpl(const Context& dev_ctx, dev_ctx.template Alloc(ddout); funcs::ElementwiseCompute, T>( - dev_ctx, ddx_safe, ddy_safe, axis, funcs::SubtractFunctor(), ddout); + dev_ctx, ddx_safe, ddy_safe, funcs::SubtractFunctor(), ddout, axis); } } diff --git a/paddle/phi/kernels/impl/elementwise_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_kernel_impl.h index 5a30fec36c3..0121f35b3ce 100644 --- a/paddle/phi/kernels/impl/elementwise_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_kernel_impl.h @@ -39,10 +39,10 @@ namespace phi { auto y_dims = y.dims(); \ if (x_dims.size() >= y_dims.size()) { \ funcs::ElementwiseCompute, T>( \ - dev_ctx, x, y, axis, funcs::name##Functor(), out); \ + dev_ctx, x, y, funcs::name##Functor(), out, axis); \ } else { \ funcs::ElementwiseCompute, T>( \ - dev_ctx, x, y, axis, funcs::Inverse##name##Functor(), out); \ + dev_ctx, x, y, 
funcs::Inverse##name##Functor(), out, axis); \ } \ } \ } @@ -62,8 +62,8 @@ namespace phi { inputs.emplace_back(&y); \ outputs.emplace_back(out); \ dev_ctx.template Alloc(out); \ - funcs::BroadcastKernel( \ - dev_ctx, inputs, &outputs, axis, funcs::name##Functor()); \ + funcs::BroadcastKernel( \ + dev_ctx, inputs, &outputs, funcs::name##Functor(), axis); \ } template @@ -72,8 +72,8 @@ void FMaxKernel(const Context& dev_ctx, const DenseTensor& y, DenseTensor* out) { dev_ctx.template Alloc(out); - funcs::ElementwiseCompute, T, T>( - dev_ctx, x, y, -1, funcs::FMaxFunctor(), out); + funcs::ElementwiseCompute, T>( + dev_ctx, x, y, funcs::FMaxFunctor(), out); } template @@ -82,8 +82,8 @@ void FMinKernel(const Context& dev_ctx, const DenseTensor& y, DenseTensor* out) { dev_ctx.template Alloc(out); - funcs::ElementwiseCompute, T, T>( - dev_ctx, x, y, -1, funcs::FMinFunctor(), out); + funcs::ElementwiseCompute, T>( + dev_ctx, x, y, funcs::FMinFunctor(), out); } } // namespace phi diff --git a/paddle/phi/kernels/impl/lu_kernel_impl.h b/paddle/phi/kernels/impl/lu_kernel_impl.h index 5315e36b471..5663484362a 100644 --- a/paddle/phi/kernels/impl/lu_kernel_impl.h +++ b/paddle/phi/kernels/impl/lu_kernel_impl.h @@ -153,12 +153,8 @@ void SetValueCompute(const Context& dev_ctx, slice_tensor.Resize(slice_dims_for_assign); if (value_tensor != nullptr) { CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims()); - phi::funcs::ElementwiseCompute, T, T>(dev_ctx, - slice_tensor, - *value_tensor, - -1, - SubFunctor(), - &slice_tensor); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, slice_tensor, *value_tensor, SubFunctor(), &slice_tensor); } else { DenseTensor value_t(dtype); auto value_dims = phi::make_ddim(shape); @@ -166,8 +162,8 @@ void SetValueCompute(const Context& dev_ctx, value_t.Resize(value_dims); dev_ctx.template Alloc(&value_t); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, slice_tensor, value_t, -1, SubFunctor(), &slice_tensor); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, slice_tensor, value_t, SubFunctor(), &slice_tensor); } slice_tensor.Resize(slice_dims); diff --git a/paddle/phi/kernels/impl/set_value_kernel_impl.h b/paddle/phi/kernels/impl/set_value_kernel_impl.h index 9956260a798..2c545ac06ad 100644 --- a/paddle/phi/kernels/impl/set_value_kernel_impl.h +++ b/paddle/phi/kernels/impl/set_value_kernel_impl.h @@ -204,7 +204,6 @@ void SetValueImpl(const Context& dev_ctx, dev_ctx, slice_tensor, value, - -1, funcs::SubtractFunctor(), &slice_tensor); } else { @@ -212,7 +211,6 @@ void SetValueImpl(const Context& dev_ctx, dev_ctx, slice_tensor, value, - -1, funcs::InverseSubtractFunctor(), &slice_tensor); } diff --git a/paddle/phi/kernels/kps/bitwise_kernel.cu b/paddle/phi/kernels/kps/bitwise_kernel.cu index 285b18927af..fcdc7c95e91 100644 --- a/paddle/phi/kernels/kps/bitwise_kernel.cu +++ b/paddle/phi/kernels/kps/bitwise_kernel.cu @@ -25,18 +25,17 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/broadcast_function.h" namespace phi { -#define DEFINE_BITWISE_KERNEL(op_type) \ - template \ - void Bitwise##op_type##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - DenseTensor* out) { \ - dev_ctx.template Alloc(out); \ - funcs::Bitwise##op_type##Functor func; \ - std::vector ins = {&x, &y}; \ - std::vector outs = {out}; \ - funcs::BroadcastKernel( \ - dev_ctx, ins, &outs, -1, func); \ +#define DEFINE_BITWISE_KERNEL(op_type) \ + template \ + void Bitwise##op_type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + dev_ctx.template Alloc(out); \ + funcs::Bitwise##op_type##Functor func; \ + std::vector ins = {&x, &y}; \ + std::vector outs = {out}; \ + funcs::BroadcastKernel(dev_ctx, ins, &outs, func); \ } DEFINE_BITWISE_KERNEL(And) diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu index 5397bf41efb..50de82cd004 100644 --- a/paddle/phi/kernels/kps/compare_kernel.cu +++ b/paddle/phi/kernels/kps/compare_kernel.cu @@ -55,8 +55,7 @@ inline void CompareKernelImpl(const Context& ctx, ctx.template Alloc(out); std::vector ins{&x, &y}; std::vector outs{out}; - funcs::BroadcastKernel( - ctx, ins, &outs, axis, Functor()); + funcs::BroadcastKernel(ctx, ins, &outs, Functor(), axis); } #ifndef PADDLE_WITH_XPU_KP diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu index 550e579ac11..e56cb4d8459 100644 --- a/paddle/phi/kernels/kps/elementwise_kernel.cu +++ b/paddle/phi/kernels/kps/elementwise_kernel.cu @@ -72,8 +72,8 @@ void HeavisideKernel(const Context& dev_ctx, inputs.emplace_back(&y); outputs.emplace_back(out); dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, -1, funcs::ElementwiseHeavisideFunctor()); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::ElementwiseHeavisideFunctor()); } template diff --git a/paddle/phi/kernels/kps/logical_kernel.cu b/paddle/phi/kernels/kps/logical_kernel.cu index 570ac0e676e..cd9e12fe367 100755 --- a/paddle/phi/kernels/kps/logical_kernel.cu +++ b/paddle/phi/kernels/kps/logical_kernel.cu @@ -25,20 +25,17 @@ namespace phi { -#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ - template \ - void Logical##type##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - DenseTensor* out) { \ - using InT = typename funcs::Logical##type##Functor::ELEMENT_TYPE; \ - using OutT = bool; \ - dev_ctx.template Alloc(out); \ - funcs::Logical##type##Functor binary_func; \ - std::vector ins = {&x, &y}; \ - std::vector outs = {out}; \ - funcs::BroadcastKernel( \ - dev_ctx, ins, &outs, -1, binary_func); \ +#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ + template \ + void Logical##type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + dev_ctx.template Alloc(out); \ + funcs::Logical##type##Functor binary_func; \ + std::vector ins = {&x, &y}; \ + std::vector outs = {out}; \ + funcs::BroadcastKernel(dev_ctx, ins, &outs, binary_func); \ } DEFINE_LOGICAL_BINARY_KERNEL(And) @@ -50,15 +47,11 @@ template void LogicalNotKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { - using InT = typename funcs::LogicalNotFunctor::ELEMENT_TYPE; - using OutT = bool; - dev_ctx.template Alloc(out); funcs::LogicalNotFunctor unary_func; std::vector ins = {&x}; std::vector outs = {out}; - funcs::BroadcastKernel( - dev_ctx, ins, &outs, -1, 
unary_func); + funcs::BroadcastKernel(dev_ctx, ins, &outs, unary_func); } } // namespace phi diff --git a/paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc b/paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc index a976cb2a009..6d051863a8e 100644 --- a/paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc @@ -30,7 +30,7 @@ void MaximumRawKernel(const Context& dev_ctx, // allocate memory for out dev_ctx.template Alloc(out); funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::MaximumFunctor(), out); + dev_ctx, x, y, funcs::MaximumFunctor(), out, axis); } template @@ -42,7 +42,7 @@ void MinimumRawKernel(const Context& dev_ctx, // allocate memory for out dev_ctx.template Alloc(out); funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::MinimumFunctor(), out); + dev_ctx, x, y, funcs::MinimumFunctor(), out, axis); } template @@ -57,10 +57,10 @@ void RemainderRawKernel(const Context& dev_ctx, auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::RemainderFunctor(), out); + dev_ctx, x, y, funcs::RemainderFunctor(), out, axis); } else { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::InverseRemainderFunctor(), out); + dev_ctx, x, y, funcs::InverseRemainderFunctor(), out, axis); } } @@ -76,10 +76,10 @@ void FloorDivideRawKernel(const Context& dev_ctx, auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::FloorDivideFunctor(), out); + dev_ctx, x, y, funcs::FloorDivideFunctor(), out, axis); } else { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::InverseFloorDivideFunctor(), out); + dev_ctx, x, y, funcs::InverseFloorDivideFunctor(), out, axis); } } @@ -95,10 +95,10 @@ void ElementwisePowRawKernel(const Context& dev_ctx, auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::ElementwisePowFunctor(), out); + dev_ctx, x, y, funcs::ElementwisePowFunctor(), out, axis); } else { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::ElementwiseInversePowFunctor(), out); + dev_ctx, x, y, funcs::ElementwiseInversePowFunctor(), out, axis); } } diff --git a/paddle/phi/kernels/legacy/kps/elementwise_raw_kernel.cu b/paddle/phi/kernels/legacy/kps/elementwise_raw_kernel.cu index 95cf5d4333e..8fab237f0a5 100644 --- a/paddle/phi/kernels/legacy/kps/elementwise_raw_kernel.cu +++ b/paddle/phi/kernels/legacy/kps/elementwise_raw_kernel.cu @@ -36,8 +36,8 @@ void MaximumRawKernel(const Context& dev_ctx, inputs.emplace_back(&y); outputs.emplace_back(out); dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, axis, funcs::MaximumFunctor()); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::MaximumFunctor(), axis); } template @@ -54,8 +54,8 @@ void MinimumRawKernel(const Context& dev_ctx, inputs.emplace_back(&y); outputs.emplace_back(out); dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, axis, funcs::MinimumFunctor()); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::MinimumFunctor(), axis); } template @@ -72,8 +72,8 @@ void RemainderRawKernel(const Context& dev_ctx, inputs.emplace_back(&y); outputs.emplace_back(out); dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, axis, funcs::RemainderFunctor()); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::RemainderFunctor(), axis); } template 
@@ -90,8 +90,8 @@ void FloorDivideRawKernel(const Context& dev_ctx, inputs.emplace_back(&y); outputs.emplace_back(out); dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, axis, funcs::FloorDivideFunctor()); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::FloorDivideFunctor(), axis); } template @@ -108,8 +108,8 @@ void ElementwisePowRawKernel(const Context& dev_ctx, inputs.emplace_back(&y); outputs.emplace_back(out); dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, axis, funcs::ElementwisePowFunctor()); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::ElementwisePowFunctor(), axis); } } // namespace phi @@ -174,4 +174,5 @@ PD_REGISTER_KERNEL(elementwise_pow_raw, float16, int64_t, bfloat16) {} + #endif diff --git a/test/cpp/phi/kernels/test_ternary_broadcast.cu b/test/cpp/phi/kernels/test_ternary_broadcast.cu index d1faa89a1ed..09598e63790 100644 --- a/test/cpp/phi/kernels/test_ternary_broadcast.cu +++ b/test/cpp/phi/kernels/test_ternary_broadcast.cu @@ -89,8 +89,7 @@ void TestCase(const phi::GPUContext& dev_ctx, d_in1.get(), d_in2.get(), d_in3.get()}; std::vector outputs{d_out.get()}; for (int i = 0; i < times; ++i) { - phi::funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, -1, compute); + phi::funcs::BroadcastKernel(dev_ctx, inputs, &outputs, compute); } dev_ctx.Wait(); } -- GitLab
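For reference, the call-site pattern applied across all of the files above, shown as a consolidated before/after. The angle-bracket template argument lists here are reconstructed from the surrounding calls and should be read as an assumption, not a quote:

// before: BroadcastKernel<ElementwiseType::kBinary, T, T>(
//             ctx, ins, &outs, -1, funcs::AddFunctor<T>());
// after:  BroadcastKernel<T>(ctx, ins, &outs, funcs::AddFunctor<T>());
//
// before: ElementwiseCompute<funcs::AddFunctor<T>, T, T>(
//             ctx, x, y, -1, funcs::AddFunctor<T>(), out);
// after:  ElementwiseCompute<funcs::AddFunctor<T>, T>(
//             ctx, x, y, funcs::AddFunctor<T>(), out);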