diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
index 7ca62b1969e261f9661ba0e2acaf1734ca809b37..94ce32cdaf182f281cfc666bd7d07fef38c3c167 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -19,17 +19,13 @@ namespace paddle {
 namespace operators {
 
-template
+template
 void LaunchElementwiseCudaKernel(
     const KPDevice &ctx,
     const std::vector &ins,
     std::vector *outs,
-    int axis,
-    Functor func) {
+    Functor func,
+    int axis = -1) {
   std::vector pt_inputs;
   std::vector pt_outputs;
   // TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
@@ -53,8 +49,8 @@ void LaunchElementwiseCudaKernel(
   for (int i = 0; i < pt_outputs_tmp.size(); i++) {
     pt_outputs.push_back(pt_outputs_tmp[i].get());
   }
-  phi::funcs::BroadcastKernel(
-      ctx, pt_inputs, &pt_outputs, axis, func);
+  phi::funcs::BroadcastKernel(
+      ctx, pt_inputs, &pt_outputs, func, axis);
 }
 
 }  // namespace operators
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index 251df9b9d9e4e3151ebfd72c670994be75b75613..ab721d35278b68aa04d4a45e50fb935626bb3ff1 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -188,7 +188,7 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
   z->mutable_data(ctx.GetPlace());
   const auto &dev_ctx = ctx.template device_context();
   phi::funcs::ElementwiseCompute(
-      dev_ctx, *x, *y, axis, func, z);
+      dev_ctx, *x, *y, func, z, axis);
 }
 
 // FusedElemwiseAndAct
@@ -1596,7 +1596,7 @@ static inline std::vector GetReduceDim(const framework::DDim &in,
 
 #if defined(__NVCC__) || defined(__HIPCC__)
 
-template
+template
 void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
                      const platform::Place &place,
                      int axis,
@@ -1605,11 +1605,11 @@ void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
                      phi::DenseTensor *dx,
                      phi::DenseTensor *dy,
                      Functor func) {
-  phi::GetGradXAndYOut(
+  phi::GetGradXAndYOut(
       dev_ctx, place, axis, ins, *dout, dx, dy, func);
 }
 
-template
+template
 void GetGradXOrYOut(const phi::GPUContext &dev_ctx,
                     const platform::Place &place,
                     int axis,
@@ -1617,8 +1617,7 @@ void GetGradXOrYOut(const phi::GPUContext &dev_ctx,
                     const phi::DenseTensor *dout,
                     phi::DenseTensor *dxy,
                     Functor func) {
-  phi::GetGradXOrYOut(
-      dev_ctx, place, axis, ins, *dout, dxy, func);
+  phi::GetGradXOrYOut(dev_ctx, place, axis, ins, *dout, dxy, func);
 }
 #endif
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 2680c8ecc5e521bf8f91a83b252ed721e8bc6083..47e317831409abd10c7adb19c85a036393b55ebe 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -23,8 +23,6 @@ limitations under the License.
 */
 namespace paddle {
 namespace operators {
 
-using ElementwiseType = phi::ElementwiseType;
-
 template
 void LaunchSameDimsElementwiseCudaKernel(
     const KPDevice &ctx,
diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h
index 9709f60bbc1ce2c11694704ba8b80661dcfba434..277e29c4d59ce56cb7c3056ee85b277165a2cbdf 100644
--- a/paddle/fluid/operators/fused/attn_gemm.h
+++ b/paddle/fluid/operators/fused/attn_gemm.h
@@ -109,8 +109,8 @@ class AttnMatMul {
       // bias_out = output + bias
       std::vector ins = {output, bias};
       std::vector outs = {bias_out};
-      phi::funcs::BroadcastKernel(
-          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor());
+      phi::funcs::BroadcastKernel(
+          dev_ctx_, ins, &outs, phi::funcs::AddFunctor());
     }
   }
diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h
index c61a7f60d43599cfcc09a2dc647f934e52f7a88f..705cb8ece418e886c88a35334b8271b924a228fc 100644
--- a/paddle/fluid/operators/fused/attn_gemm_int8.h
+++ b/paddle/fluid/operators/fused/attn_gemm_int8.h
@@ -85,8 +85,8 @@ class AttnMatmulINT8 {
     // bias_out = output + bias
     std::vector ins = {output, bias};
     std::vector outs = {bias_out};
-    phi::funcs::BroadcastKernel(
-        dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor());
+    phi::funcs::BroadcastKernel(
+        dev_ctx_, ins, &outs, phi::funcs::AddFunctor());
     PADDLE_ENFORCE_EQ(cudaGetLastError(),
                       cudaSuccess,
                       platform::errors::Fatal(
@@ -139,8 +139,8 @@ class AttnMatmulINT8 {
     // bias_out = output + bias
     std::vector ins = {output, bias};
     std::vector outs = {bias_out};
-    phi::funcs::BroadcastKernel(
-        dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor());
+    phi::funcs::BroadcastKernel(
+        dev_ctx_, ins, &outs, phi::funcs::AddFunctor());
     PADDLE_ENFORCE_EQ(cudaGetLastError(),
                       cudaSuccess,
                       platform::errors::Fatal(
diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h
index 1d83c7a62b1d94031c0b6bbde81dd5504a056068..843b5009a6fccd03b708db7fc07300b8df8828ca 100644
--- a/paddle/fluid/operators/fused/fmha_ref.h
+++ b/paddle/fluid/operators/fused/fmha_ref.h
@@ -255,12 +255,11 @@ class FMHARef {
       ins.emplace_back(src_mask_tensor);
       outs.emplace_back(src_mask_out_tensor);
       int elewise_add_axis = -1;
-      phi::funcs::BroadcastKernel(
-          dev_ctx_,
-          ins,
-          &outs,
-          elewise_add_axis,
-          phi::funcs::AddFunctor());
+      phi::funcs::BroadcastKernel(dev_ctx_,
+                                  ins,
+                                  &outs,
+                                  phi::funcs::AddFunctor(),
+                                  elewise_add_axis);
 
       phi::SoftmaxForwardCUDAKernelDriver(
           dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
@@ -432,12 +431,11 @@ class FMHARef {
       ins.emplace_back(src_mask_tensor);
       outs.emplace_back(src_mask_out_tensor);
       int elewise_add_axis = -1;
-      phi::funcs::BroadcastKernel(
-          dev_ctx_,
-          ins,
-          &outs,
-          elewise_add_axis,
-          phi::funcs::AddFunctor());
+      phi::funcs::BroadcastKernel(dev_ctx_,
+                                  ins,
+                                  &outs,
+                                  phi::funcs::AddFunctor(),
+                                  elewise_add_axis);
 
       phi::SoftmaxForwardCUDAKernelDriver(
           dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h
index c8c4733df2e2e53aa8a5e7afdb3120bde8a66d7a..105647baf1c35ff55cc24e8a82fd75f9652e32c0 100644
--- a/paddle/fluid/operators/fused/fused_gate_attention.h
+++ b/paddle/fluid/operators/fused/fused_gate_attention.h
@@ -689,13 +689,13 @@ class FMHAGateRef {
       std::vector ins = {
          qk_out, src_mask, nonbatched_bias};
       std::vector outs = {qk_out};
-      phi::funcs::BroadcastKernel(
-          dev_ctx_, ins, &outs, -1,
TernaryAddFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, TernaryAddFunctor()); } else { std::vector ins = {qk_out, src_mask}; std::vector outs = {qk_out}; - phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); + phi::funcs::BroadcastKernel( + dev_ctx_, ins, &outs, phi::funcs::AddFunctor()); } phi::SoftmaxForwardCUDAKernelDriver(dev_ctx_, *qk_out, -1, softmax_out); } diff --git a/paddle/fluid/operators/fused_token_prune_op.cu b/paddle/fluid/operators/fused_token_prune_op.cu index 434c072e5aa6a3e66b6ddf96fa76f52206871999..8f0a53611f3b29f4e58b22138fac5193981865cc 100644 --- a/paddle/fluid/operators/fused_token_prune_op.cu +++ b/paddle/fluid/operators/fused_token_prune_op.cu @@ -141,8 +141,7 @@ class FusedTokenPruneOpCUDAKernel : public framework::OpKernel { ins.emplace_back(attn); ins.emplace_back(mask); outs.emplace_back(&attn_tmp); - LaunchElementwiseCudaKernel( - dev_ctx, ins, &outs, -1, AttnMaskFunctor()); + LaunchElementwiseCudaKernel(dev_ctx, ins, &outs, AttnMaskFunctor()); // 2. Reduce sum const std::vector reduce_dims{1, 2}; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index f9b79010263088846b490425d90017caefaf82cb..aaa86f8c37f62629eb674ffe41c681c0a85c62f8 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -836,12 +836,11 @@ class ReduceCudaGradKernel : public framework::OpKernel { } using MPType = typename kps::details::MPTypeTrait::Type; - phi::ReduceGrad>( - dev_ctx, - pt_d_out.get(), - pt_d_x.get(), - pt_out_dtype, - TransformOp(reduce_num)); + phi::ReduceGrad>(dev_ctx, + pt_d_out.get(), + pt_d_x.get(), + pt_out_dtype, + TransformOp(reduce_num)); } }; diff --git a/paddle/phi/kernels/cpu/bitwise_kernel.cc b/paddle/phi/kernels/cpu/bitwise_kernel.cc index 80424ef624f61bb8f28de66cf122053ef1514c1d..a6297efd9cd3e284de3a0dcaa1e6b3007394a51a 100644 --- a/paddle/phi/kernels/cpu/bitwise_kernel.cc +++ b/paddle/phi/kernels/cpu/bitwise_kernel.cc @@ -24,15 +24,15 @@ limitations under the License. 
*/ namespace phi { -#define DEFINE_BITWISE_KERNEL(op_type) \ - template \ - void Bitwise##op_type##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - DenseTensor* out) { \ - funcs::Bitwise##op_type##Functor func; \ - funcs::ElementwiseCompute, T, T>( \ - dev_ctx, x, y, -1, func, out); \ +#define DEFINE_BITWISE_KERNEL(op_type) \ + template \ + void Bitwise##op_type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + funcs::Bitwise##op_type##Functor func; \ + funcs::ElementwiseCompute, T>( \ + dev_ctx, x, y, func, out); \ } DEFINE_BITWISE_KERNEL(And) diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index cf8eb47fb427f6b73fc5b96592440678609086cd..0fd1332d76c641bc7779bbc8c1fd90ab08297378 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -33,10 +33,10 @@ inline void CompareKernelImpl(const Context& ctx, ctx.template Alloc(out); if (x.dims().size() >= y.dims().size()) { funcs::ElementwiseCompute( - ctx, x, y, axis, Functor(), out); + ctx, x, y, Functor(), out, axis); } else { funcs::ElementwiseCompute( - ctx, x, y, axis, InverseFunctor(), out); + ctx, x, y, InverseFunctor(), out, axis); } } @@ -59,7 +59,7 @@ inline void CompareAllKernelImpl(const Context& ctx, tmp_data[0] = Functor()(x.data()[0], y.data()[0]); } else { funcs::ElementwiseCompute( - ctx, x, y, 0, Functor(), &tmp); + ctx, x, y, Functor(), &tmp, 0); } auto tmp_flat = EigenVector::Flatten(tmp); auto out_es = EigenScalar::From(*out); diff --git a/paddle/phi/kernels/cpu/dirichlet_kernel.cc b/paddle/phi/kernels/cpu/dirichlet_kernel.cc index c124920dfa0db8af9cb85e5b4b5889b664dfe989..855e6bdfe1e1ff19fcf4f6e08616e6feef368300 100644 --- a/paddle/phi/kernels/cpu/dirichlet_kernel.cc +++ b/paddle/phi/kernels/cpu/dirichlet_kernel.cc @@ -91,8 +91,8 @@ struct DirichletSampler { true, false); - funcs::ElementwiseCompute, T, T>( - dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor(), out); + funcs::ElementwiseCompute, T>( + dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor(), out); } }; diff --git a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc index 40d0e863ea50e92484e80e9f005f7cf2059caeeb..8dbbfc13b81e2e4d578551fe991d6965e5f4297a 100644 --- a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc @@ -38,10 +38,10 @@ void DivideRawKernel(const Context& dev_ctx, auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::DivideFunctor(), out); + dev_ctx, x, y, funcs::DivideFunctor(), out, axis); } else { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::InverseDivideFunctor(), out); + dev_ctx, x, y, funcs::InverseDivideFunctor(), out, axis); } } } diff --git a/paddle/phi/kernels/cpu/elementwise_kernel.cc b/paddle/phi/kernels/cpu/elementwise_kernel.cc index 11aac8bbfe3ad37749d1098d81a977db6aaffd2e..321b439547e8d96c7fc906ba61882446199f1b05 100644 --- a/paddle/phi/kernels/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_kernel.cc @@ -30,7 +30,7 @@ void MaximumRawKernel(const Context& dev_ctx, // allocate memory for out dev_ctx.template Alloc(out); funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::MaximumFunctor(), out); + dev_ctx, x, y, funcs::MaximumFunctor(), out, axis); } template @@ -42,7 +42,7 @@ void MinimumRawKernel(const 
Context& dev_ctx, // allocate memory for out dev_ctx.template Alloc(out); funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::MinimumFunctor(), out); + dev_ctx, x, y, funcs::MinimumFunctor(), out, axis); } template @@ -57,10 +57,10 @@ void RemainderRawKernel(const Context& dev_ctx, auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::RemainderFunctor(), out); + dev_ctx, x, y, funcs::RemainderFunctor(), out, axis); } else { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::InverseRemainderFunctor(), out); + dev_ctx, x, y, funcs::InverseRemainderFunctor(), out, axis); } } @@ -76,10 +76,10 @@ void FloorDivideRawKernel(const Context& dev_ctx, auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::FloorDivideFunctor(), out); + dev_ctx, x, y, funcs::FloorDivideFunctor(), out, axis); } else { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::InverseFloorDivideFunctor(), out); + dev_ctx, x, y, funcs::InverseFloorDivideFunctor(), out, axis); } } @@ -95,10 +95,10 @@ void ElementwisePowRawKernel(const Context& dev_ctx, auto y_dims = y.dims(); if (x_dims.size() >= y_dims.size()) { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::ElementwisePowFunctor(), out); + dev_ctx, x, y, funcs::ElementwisePowFunctor(), out, axis); } else { funcs::ElementwiseCompute, T>( - dev_ctx, x, y, axis, funcs::ElementwiseInversePowFunctor(), out); + dev_ctx, x, y, funcs::ElementwiseInversePowFunctor(), out, axis); } } @@ -110,7 +110,7 @@ void HeavisideKernel(const Context& dev_ctx, // allocate memory for out dev_ctx.template Alloc(out); funcs::ElementwiseCompute, T>( - dev_ctx, x, y, -1, funcs::ElementwiseHeavisideFunctor(), out); + dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor(), out); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc index c42e423ba2d34c000e396b9f1f623f46e60f975c..630e786b571bc7cfbc38cf87c87074e27ae01546 100644 --- a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc @@ -68,20 +68,15 @@ void LayerNormGradKernel(const Context& dev_ctx, temp_norm.Resize(matrix_shape); dev_ctx.template Alloc(&temp_norm); // get x_norm - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - x_tmp, - mean, - /*axis*/ 0, - funcs::SubtractFunctor(), - &temp_norm); - phi::funcs::ElementwiseCompute, T, T>( + phi::funcs::ElementwiseCompute, T>( + dev_ctx, x_tmp, mean, funcs::SubtractFunctor(), &temp_norm, 0); + phi::funcs::ElementwiseCompute, T>( dev_ctx, temp_norm, variance, - /*axis*/ 0, funcs::DivAndSqrtFunctor(static_cast(epsilon)), - &temp_norm); + &temp_norm, + 0); } if (d_bias) { @@ -90,8 +85,8 @@ void LayerNormGradKernel(const Context& dev_ctx, } if (d_scale) { dev_ctx.template Alloc(d_scale); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, temp_norm, d_y, 0, funcs::MultiplyFunctor(), &temp); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, temp_norm, d_y, funcs::MultiplyFunctor(), &temp, 0); colwise_sum(dev_ctx, temp, d_scale); } @@ -107,70 +102,45 @@ void LayerNormGradKernel(const Context& dev_ctx, if (d_scale) { // dy_dx - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, d_y, *scale, /*axis*/ 1, funcs::MultiplyFunctor(), &temp); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, d_y, *scale, funcs::MultiplyFunctor(), &temp, 1); phi::Copy(dev_ctx, temp, dev_ctx.GetPlace(), false, d_x); // dy_dmean_dx 
row_mean(dev_ctx, temp, &temp_vec); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - *d_x, - temp_vec, - /*axis*/ 0, - funcs::SubtractFunctor(), - d_x); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, *d_x, temp_vec, funcs::SubtractFunctor(), d_x, 0); // dy_var_dx - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - temp, - temp_norm, - /*axis*/ 0, - funcs::MultiplyFunctor(), - &temp); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, temp, temp_norm, funcs::MultiplyFunctor(), &temp, 0); } else { // dy_dx phi::Copy(dev_ctx, d_y, dev_ctx.GetPlace(), false, d_x); // dy_dmean_dx row_mean(dev_ctx, d_y, &temp_vec); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - *d_x, - temp_vec, - /*axis*/ 0, - funcs::SubtractFunctor(), - d_x); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, *d_x, temp_vec, funcs::SubtractFunctor(), d_x, 0); // dy_var_dx - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - d_y, - temp_norm, - /*axis*/ 0, - funcs::MultiplyFunctor(), - &temp); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, d_y, temp_norm, funcs::MultiplyFunctor(), &temp, 0); } // dy_var_dx row_mean(dev_ctx, temp, &temp_vec); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, - temp_norm, - temp_vec, - /*axis*/ 0, - funcs::MultiplyFunctor(), - &temp); - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, *d_x, temp, /*axis*/ 0, funcs::SubtractFunctor(), d_x); - - phi::funcs::ElementwiseCompute, T, T>( + phi::funcs::ElementwiseCompute, T>( + dev_ctx, temp_norm, temp_vec, funcs::MultiplyFunctor(), &temp, 0); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, *d_x, temp, funcs::SubtractFunctor(), d_x, 0); + + phi::funcs::ElementwiseCompute, T>( dev_ctx, *d_x, variance, - /*axis*/ 0, funcs::DivAndSqrtFunctor(static_cast(epsilon)), - d_x); + d_x, + 0); d_x->Resize(dx_dim); } } diff --git a/paddle/phi/kernels/cpu/layer_norm_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_kernel.cc index 1c82866f0bbda06ed35a8e9390c80c3d6305015d..2a93d03b4abc15a6a544426ea12ffe3a2320a2bb 100644 --- a/paddle/phi/kernels/cpu/layer_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/layer_norm_kernel.cc @@ -67,30 +67,30 @@ void LayerNormKernel(const Context& dev_ctx, // get variance - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, x_tmp, *mean, 0, funcs::SubAndSquareFunctor(), &out); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, x_tmp, *mean, funcs::SubAndSquareFunctor(), &out, 0); row_mean(dev_ctx, out, var); // get x_norm - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, x_tmp, *mean, 0, funcs::SubtractFunctor(), &out); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, x_tmp, *mean, funcs::SubtractFunctor(), &out, 0); - phi::funcs::ElementwiseCompute, T, T>( + phi::funcs::ElementwiseCompute, T>( dev_ctx, out, *var, - 0, funcs::DivAndSqrtFunctor(static_cast(epsilon)), - &out); + &out, + 0); if (scale) { - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, out, *scale, 1, funcs::MultiplyFunctor(), &out); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, out, *scale, funcs::MultiplyFunctor(), &out, 1); } if (bias) { - phi::funcs::ElementwiseCompute, T, T>( - dev_ctx, out, *bias, 1, funcs::AddFunctor(), &out); + phi::funcs::ElementwiseCompute, T>( + dev_ctx, out, *bias, funcs::AddFunctor(), &out, 1); } #else PADDLE_ENFORCE_EQ(mean->numel(), diff --git a/paddle/phi/kernels/cpu/logical_kernel.cc b/paddle/phi/kernels/cpu/logical_kernel.cc index 3a669883f78378db76d19103336870c3171b6126..38c927e976f8c874e2c18f17d8b0d6c09e57a9af 100644 --- a/paddle/phi/kernels/cpu/logical_kernel.cc +++ b/paddle/phi/kernels/cpu/logical_kernel.cc @@ -32,7 
+32,7 @@ namespace phi { DenseTensor* out) { \ funcs::Logical##type##Functor binary_func; \ funcs::ElementwiseCompute, T, bool>( \ - dev_ctx, x, y, -1, binary_func, out); \ + dev_ctx, x, y, binary_func, out); \ } DEFINE_LOGICAL_BINARY_KERNEL(And) diff --git a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc index fbb16138567d154360acb33bb94ba695dd48ed0b..e1188fda486c77cdb6b588e857377e34b8d3b967 100644 --- a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc @@ -132,11 +132,10 @@ void MatrixRankTolKernel(const Context& dev_ctx, DenseTensor tol_tensor; tol_tensor.Resize(dim_out); dev_ctx.template Alloc(&tol_tensor); - funcs::ElementwiseCompute, T, T>( + funcs::ElementwiseCompute, T>( dev_ctx, atol_tensor, rtol_tensor, - -1, GreaterElementFunctor(), &tol_tensor); @@ -151,17 +150,17 @@ void MatrixRankTolKernel(const Context& dev_ctx, dev_ctx, eigenvalue_tensor, tol_tensor, - axis, funcs::GreaterThanFunctor(), - &compare_result); + &compare_result, + axis); } else { funcs::ElementwiseCompute, T, int>( dev_ctx, eigenvalue_tensor, tol_tensor, - axis, funcs::LessThanFunctor(), - &compare_result); + &compare_result, + axis); } phi::SumKernel(dev_ctx, diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index f96a1764c24a5e63db1fbf8e40e79535a5c4f309..e754ce3bf49e4659f885e9d94a116bc98ef0aa26 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -31,20 +31,49 @@ namespace funcs { enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 }; -template +template +struct UseBroadcast { + template + static HOSTDEVICE void Apply( + const std::vector &ins_tensor, + const ArgsT &args, + int64_t numel, + Array1 *ins_data, + Array2 *use_broadcast, + int *broadcast_num, + bool *all_elementwise) { + (*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data()); + bool is_same_dim = ins_tensor[Index]->numel() == numel; + if (is_same_dim) { + (*use_broadcast)[Index] = false; + } else { + (*use_broadcast)[Index] = true; + (*broadcast_num)++; + } + *all_elementwise &= is_same_dim; + } +}; + +template struct LoaderTypeClassifier { public: int64_t numel{0}; - int vec_size{1}; + int vec_size{4}; int broadcast_num{0}; bool all_elementwise{true}; - phi::Array use_broadcast; - phi::Array ins_data; + phi::Array use_broadcast; + phi::Array ins_data; LoaderTypeClassifier() {} LoaderTypeClassifier(const std::vector &ins, std::vector *outs) { + using Traits = phi::funcs::FunctionTraits; + using ArgsT = typename Traits::ArgsTuple; + ArgsT arg; uint64_t out_addr = reinterpret_cast((*outs)[0]->data()); + + UnrollerWithoutVecSize::step(ins, arg, &vec_size); + for (auto i = 1; i < outs->size(); ++i) { PADDLE_ENFORCE_EQ( (*outs)[i]->dims(), @@ -56,165 +85,191 @@ struct LoaderTypeClassifier { out_addr = (out_addr | reinterpret_cast((*outs)[i]->data())); } - int out_vec_size = - phi::GetVectorizedSize(reinterpret_cast(out_addr)); - uint64_t in_addr = static_cast(0); + vec_size = std::min( + vec_size, + phi::GetVectorizedSize(reinterpret_cast(out_addr))); numel = (*outs)[0]->numel(); - for (int i = 0; i < Arity; ++i) { - auto in_data = ins[i]->data(); - ins_data[i] = (const _ptr_ InT *)(in_data); - - bool is_same_dim = ins[i]->numel() == numel; - if (is_same_dim) { - use_broadcast[i] = false; - in_addr = (in_addr | reinterpret_cast(in_data)); - } else { - use_broadcast[i] = true; - broadcast_num++; - } - 
all_elementwise &= is_same_dim; - } - int in_vec_size = std::min( - 4, phi::GetVectorizedSize(reinterpret_cast(in_addr))); - vec_size = std::min(out_vec_size, in_vec_size); + UnrollerWithoutVecSize::step(ins, + arg, + numel, + &ins_data, + &use_broadcast, + &broadcast_num, + &all_elementwise); } }; -#ifndef PADDLE_WITH_XPU_KP // Common broadcast/elementwise Loader. -template +template struct BroadcastDataLoader { - __device__ __forceinline__ void operator()( - T args[Arity][VecSize], - const phi::Array &ins, - const phi::Array &configs, - const phi::Array &use_broadcast, - const int block_offset, - const int num, - const uint32_t numel) { -#pragma unroll - for (int i = 0; i < Arity; ++i) { - kps::Init(args[i], static_cast(1.0f)); - if (use_broadcast[i]) { - kps::ReadDataBc( - args[i], ins[i], block_offset, configs[i], numel, VecSize); - } else { - kps::ReadData( - args[i], ins[i] + block_offset, num, VecSize); - } + template + static __device__ __forceinline__ void Apply(const Array1 &ins, + ArgsT *args, + const Array2 &configs, + const Array3 &use_broadcast, + const int block_offset, + const int num, + const uint32_t numel, + int read_lens) { + using Type = std::tuple_element_t; +#ifdef PADDLE_WITH_XPU_KP + kps::Init( + args, static_cast(1.0f), read_lens); + if (use_broadcast[Index]) { + kps::ReadDataBc( + args, + reinterpret_cast(ins[Index]), + block_offset, + configs[Index], + numel, + read_lens); + } else { + kps::ReadData( + args, + reinterpret_cast(ins[Index]) + block_offset, + num, + read_lens); } +#else + kps::Init(args, static_cast(1.0f)); + if (use_broadcast[Index]) { + kps::ReadDataBc( + args, + reinterpret_cast(ins[Index]), + block_offset, + configs[Index], + numel, + VecSize); + } + // NOTE: If use if...else... with condition `use_broadcast[Index]` here, + // there will be some errs with clang12 while compiling in ROCm. + // When the compiler is upgraded, if...else... may be used. + if (!use_broadcast[Index]) { + kps::ReadData( + args, + reinterpret_cast(ins[Index]) + block_offset, + num, + VecSize); + } +#endif } }; +/* BroadcastDataLoaders Partial specialization */ +#ifndef PADDLE_WITH_XPU_KP // Scalar elementwise Loader with consideration of IsBoundary. -template -struct BroadcastDataLoader { - __device__ __forceinline__ void operator()( - T args[Arity][VecSize], - const phi::Array &ins, - const phi::Array &configs, - const phi::Array &use_broadcast, - const int block_offset, - const int num, - const uint32_t numel) { +template +struct BroadcastDataLoader { + template + static __device__ __forceinline__ void Apply(const Array1 &ins, + ArgsT *args, + const Array2 &configs, + const Array3 &use_broadcast, + const int block_offset, + const int num, + const uint32_t numel, + int read_lens) { + using Type = std::tuple_element_t; int thread_offset = threadIdx.x * VecSize + block_offset; #pragma unroll - for (int i = 0; i < Arity; ++i) { -#pragma unroll - for (int idx = 0; idx < VecSize; ++idx) { - args[i][idx] = static_cast(1); - int index = thread_offset + idx; - if (index < numel) { - args[i][idx] = ins[i][index]; - } + for (int idx = 0; idx < VecSize; ++idx) { + std::get(args[idx]) = static_cast(1); + int index = thread_offset + idx; + if (index < numel) { + std::get(args[idx]) = + reinterpret_cast(ins[Index])[index]; } } } }; // Vectorized elementwise Loader without consideration of IsBoundary. 
-template -struct BroadcastDataLoader { - __device__ __forceinline__ void operator()( - T args[Arity][VecSize], - const phi::Array &ins, - const phi::Array &configs, - const phi::Array &use_broadcast, - const int block_offset, - const int num, - const uint32_t numel) { - using VecType = phi::kps::details::VectorType; - VecType vec_temp[Arity]; +template +struct BroadcastDataLoader { + template + static __device__ __forceinline__ void Apply(const Array1 &ins, + ArgsT *args, + const Array2 &configs, + const Array3 &use_broadcast, + const int block_offset, + const int num, + const uint32_t numel, + int read_lens) { + using Type = std::tuple_element_t; + using VecType = phi::kps::details::VectorType; + VecType vec_temp; int thread_offset = threadIdx.x + blockIdx.x * blockDim.x; + const VecType *__restrict__ vec_input = + reinterpret_cast(ins[Index]); + vec_temp = vec_input[thread_offset]; #pragma unroll - for (int i = 0; i < Arity; ++i) { - const VecType *__restrict__ vec_input = - reinterpret_cast(ins[i]); - vec_temp[i] = vec_input[thread_offset]; -#pragma unroll - for (int idx = 0; idx < VecSize; ++idx) { - args[i][idx] = vec_temp[i].val[idx]; - } + for (int idx = 0; idx < VecSize; ++idx) { + std::get(args[idx]) = vec_temp.val[idx]; } } }; -// Common broadcast data loader. -template -struct BroadcastDataLoader { - __device__ __forceinline__ void operator()( - T args[Arity][VecSize], - const phi::Array &ins, - const phi::Array &configs, - const phi::Array &use_broadcast, - const int block_offset, - const int num, - const uint32_t numel) { - uint32_t index_bc[Arity][VecSize]; -#pragma unroll - for (int j = 0; j < Arity; ++j) { -#pragma unroll - for (int k = 0; k < VecSize; ++k) { - index_bc[j][k] = 0; - args[j][k] = static_cast(1); - } - } - - uint32_t thread_offset = block_offset + threadIdx.x * VecSize; +template +struct BroadcastDataInit { + template + static __device__ __forceinline__ void Apply(ArgsT *args) { + using Type = std::tuple_element_t; #pragma unroll for (int k = 0; k < VecSize; ++k) { - uint32_t idx = thread_offset + k; - if (IsBoundary) { - if (idx == numel) break; - } - -#pragma unroll - for (int i = 0; i < phi::DDim::kMaxRank; ++i) { - if (i == configs[0].rank) break; - auto fast_divmoder = configs[0].divmoders[i].Divmod(idx); - idx = fast_divmoder.val[0]; -#pragma unroll - for (int j = 0; j < Arity; ++j) { - index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i]; - } - } + std::get(args[k]) = static_cast(1); } + } +}; +template +struct BroadcastDataSetter { + template + static __device__ __forceinline__ void Apply(const Array &ins, + ArgsT *args, + uint32_t index_bc[][VecSize]) { + using Type = std::tuple_element_t; #pragma unroll - for (int j = 0; j < Arity; ++j) { -#pragma unroll - for (int k = 0; k < VecSize; ++k) { - args[j][k] = ins[j][index_bc[j][k]]; - } + for (int k = 0; k < VecSize; ++k) { + std::get(args[k]) = + reinterpret_cast(ins[Index])[index_bc[Index][k]]; } } }; + #endif -template + typename Func, + bool IsBoundary, + int LoadType, + int VecSize, + int End, + int Begin = 0> +struct BcUnroller { + template + static HOSTDEVICE inline void step(Args &&...args) { + Func::Apply( + std::forward(args)...); + BcUnroller::step( + args...); + } +}; + +template
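Note for readers of this diff: the churn above is two mechanical patterns applied everywhere. First, phi::funcs::BroadcastKernel, phi::funcs::ElementwiseCompute, and LaunchElementwiseCudaKernel now take the functor before the output argument and take axis as a trailing parameter defaulting to -1, so call sites that used to pass -1 simply drop it. Second, in broadcast_function.h the loaders no longer fill a homogeneous T args[Arity][VecSize] array inside a runtime loop over Arity; instead BcUnroller / UnrollerWithoutVecSize instantiate BroadcastDataLoader, BroadcastDataInit, and BroadcastDataSetter once per input index and write into the functor's argument tuple (ArgsT from FunctionTraits), so each input can keep its own element type. The sketch below is a small, self-contained analogue of both patterns; Tensor, Context, BroadcastKernelLike, Unroller, and LoadArg are hypothetical names used only for illustration, not Paddle APIs.

// Standalone analogue (not Paddle code) of the two refactor patterns above.
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <utility>
#include <vector>

struct Tensor { std::vector<float> data; };  // stand-in for phi::DenseTensor
struct Context {};                           // stand-in for the device context

// Pattern 1: the functor moves ahead of the remaining arguments and `axis`
// becomes a trailing default, so call sites that used to pass -1 drop it.
template <typename Functor>
void BroadcastKernelLike(const Context &ctx,
                         const std::vector<const Tensor *> &ins,
                         std::vector<Tensor *> *outs,
                         Functor func,
                         int axis = -1) {
  // Minimal binary, same-shape fallback; real broadcasting is omitted.
  Tensor *out = (*outs)[0];
  for (std::size_t i = 0; i < out->data.size(); ++i) {
    out->data[i] = func(ins[0]->data[i], ins[1]->data[i]);
  }
  (void)ctx;
  (void)axis;
}

// Pattern 2: a compile-time unroller that instantiates Func<Index>::Apply once
// per input, instead of a runtime `for (i < Arity)` loop over one shared type.
template <template <int> class Func, int End, int Begin = 0>
struct Unroller {
  template <typename... Args>
  static void Step(Args &&...args) {
    Func<Begin>::Apply(std::forward<Args>(args)...);
    Unroller<Func, End, Begin + 1>::Step(args...);
  }
};
template <template <int> class Func, int End>
struct Unroller<Func, End, End> {
  template <typename... Args>
  static void Step(Args &&...) {}
};

// Per-argument loader: writes element `i` of input `Index` into slot `Index`
// of the argument tuple, so each input may use its own element type.
template <int Index>
struct LoadArg {
  template <typename ArgsT>
  static void Apply(const std::vector<const Tensor *> &ins,
                    ArgsT *args,
                    std::size_t i) {
    std::get<Index>(*args) = ins[Index]->data[i];
  }
};

int main() {
  Tensor a{{1, 2, 3}}, b{{10, 20, 30}}, c;
  c.data.resize(3);
  std::vector<const Tensor *> ins{&a, &b};
  std::vector<Tensor *> outs{&c};

  // `axis` is omitted, as at most of the call sites touched by this diff.
  BroadcastKernelLike(
      Context{}, ins, &outs, [](float x, float y) { return x + y; });

  // Fill both argument slots for element 1 via the compile-time unroller.
  std::tuple<float, float> args;
  Unroller<LoadArg, 2>::Step(ins, &args, std::size_t{1});
  std::printf("out[0] = %g, args = (%g, %g)\n",
              c.data[0], std::get<0>(args), std::get<1>(args));
  return 0;
}

Because each loader step is indexed, the element type of input Index can be recovered from the functor's argument tuple instead of being fixed once for every input, which is also why LoaderTypeClassifier above now derives vec_size through UnrollerWithoutVecSize over that tuple rather than from a single shared input pointer type.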