From 52ef86564a165f6957d295c3e10c0be994bec710 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Tue, 7 Jun 2022 10:09:10 +0800
Subject: [PATCH] [cherry-pick] Delete ElementwiseKernel in BroadcastKernel
 (#42779) (#43210)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Delete ElementwiseKernel in BroadcastKernel

Reduce the duplicated functionality calls in all Broadcast kernels, and
cut compile time and binary size at the same time.
---
 paddle/phi/kernels/funcs/broadcast_function.h | 22 +++--------
 paddle/phi/kernels/gpu/bitwise_kernel.cu      |  6 +--
 paddle/phi/kernels/gpu/gelu_grad_kernel.cu    | 10 +++--
 paddle/phi/kernels/gpu/gelu_kernel.cu         | 10 +++--
 paddle/phi/kernels/gpu/reduce_grad.h          | 21 ++++-------
 .../kernels/gpu/reduce_mean_grad_kernel.cu    | 19 +++++++++-
 .../phi/kernels/gpu/reduce_sum_grad_kernel.cu | 37 +++++++++++++++++--
 paddle/phi/kernels/gpu/where_kernel.cu        |  3 +-
 8 files changed, 80 insertions(+), 48 deletions(-)
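After this change BroadcastKernel always takes the broadcast path, so a call
site whose inputs and outputs all share one shape calls ElementwiseKernel
directly, as each file below now does. A minimal sketch of that call-site
pattern, following the phi conventions visible in this diff (MyFunctor and
MyUnaryKernel are hypothetical names, not part of the patch):

    #include <vector>

    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/core/hostdevice.h"
    #include "paddle/phi/kernels/funcs/elementwise_base.h"  // ElementwiseKernel

    namespace phi {

    // Hypothetical element-wise functor: no shape logic lives here.
    template <typename T>
    struct MyFunctor {
      HOSTDEVICE inline T operator()(const T x) const {
        return x + static_cast<T>(1);
      }
    };

    // Hypothetical kernel whose input and output always share one shape.
    template <typename T, typename Context>
    void MyUnaryKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       DenseTensor* out) {
      dev_ctx.template Alloc<T>(out);
      std::vector<const DenseTensor*> ins = {&x};
      std::vector<DenseTensor*> outs = {out};
      // Same-shape input/output: call ElementwiseKernel directly instead of
      // relying on the fast path this patch deletes from BroadcastKernel.
      funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, MyFunctor<T>());
    }

    }  // namespace phi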
diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h
index 2a4c46eb79..514ecddfe2 100644
--- a/paddle/phi/kernels/funcs/broadcast_function.h
+++ b/paddle/phi/kernels/funcs/broadcast_function.h
@@ -496,26 +496,16 @@ void BroadcastKernel(const KPDevice &ctx,
                      int axis,
                      Functor func) {
   std::vector<int> dims_size;
   dims_size.reserve(ins.size());
-  bool no_broadcast_flag = true;
   for (auto *in : ins) {
-    no_broadcast_flag &= ins[0]->dims() == in->dims();
     dims_size.emplace_back(in->dims().size());
   }
-  if (ins.size() > 0 && outs->size() > 0) {
-    no_broadcast_flag &= outs->at(0)->dims() == ins[0]->dims();
-  }
-
-  if (no_broadcast_flag) {
-    phi::funcs::ElementwiseKernel<OutT, Functor, NumOuts>(ctx, ins, outs, func);
-  } else {
-    axis = axis == -1
-               ? *std::max_element(dims_size.begin(), dims_size.end()) -
-                     *std::min_element(dims_size.begin(), dims_size.end())
-               : axis;
-    BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
-        ctx, ins, outs, axis, func);
-  }
+  axis = axis == -1
+             ? *std::max_element(dims_size.begin(), dims_size.end()) -
+                   *std::min_element(dims_size.begin(), dims_size.end())
+             : axis;
+  BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
+      ctx, ins, outs, axis, func);
 }
 
 template <typename Functor, typename T, typename OutType = T>
diff --git a/paddle/phi/kernels/gpu/bitwise_kernel.cu b/paddle/phi/kernels/gpu/bitwise_kernel.cu
index e88ecef318..dc189a7fd7 100644
--- a/paddle/phi/kernels/gpu/bitwise_kernel.cu
+++ b/paddle/phi/kernels/gpu/bitwise_kernel.cu
@@ -46,9 +46,9 @@ void BitwiseNotKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
-  funcs::BitwiseNotFunctor<T> func;
-  funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
-      dev_ctx, ins, &outs, -1, func);
+  funcs::BitwiseNotFunctor<T> unary_func;
+  funcs::ElementwiseKernel<T, funcs::BitwiseNotFunctor<T>>(
+      dev_ctx, ins, &outs, unary_func);
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/gelu_grad_kernel.cu b/paddle/phi/kernels/gpu/gelu_grad_kernel.cu
index 1e21f8d426..1f33d5c901 100644
--- a/paddle/phi/kernels/gpu/gelu_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/gelu_grad_kernel.cu
@@ -81,11 +81,13 @@ void GeluGradKernel(const Context& dev_ctx,
       }
     }
 #endif
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
+    using Functor = GeluWithApproximateGradFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   } else {
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
+    using Functor = GeluWithoutApproximateGradFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   }
 }
 
diff --git a/paddle/phi/kernels/gpu/gelu_kernel.cu b/paddle/phi/kernels/gpu/gelu_kernel.cu
index ce6dda2d6c..00dc58df0d 100644
--- a/paddle/phi/kernels/gpu/gelu_kernel.cu
+++ b/paddle/phi/kernels/gpu/gelu_kernel.cu
@@ -71,11 +71,13 @@ void GeluKernel(const Context& dev_ctx,
       }
     }
 #endif
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
+    using Functor = GeluWithApproximateFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   } else {
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
+    using Functor = GeluWithoutApproximateFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   }
 }
 
diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h
index 1e39a08e9c..8656076497 100644
--- a/paddle/phi/kernels/gpu/reduce_grad.h
+++ b/paddle/phi/kernels/gpu/reduce_grad.h
@@ -43,22 +43,19 @@ void ReduceGrad(const GPUContext& dev_ctx,
       }));
 }
 
-template <typename T,
-          typename Context,
-          template <typename, typename> class TransformOp>
+template <typename T, typename Context, typename Functor>
 void ReduceGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& out_grad,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       bool reduce_all,
-                      DenseTensor* x_grad) {
+                      DenseTensor* x_grad,
+                      Functor functor) {
   auto* in_x = &x;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
-  auto pt_out_dtype = x.dtype();
-
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims, dim_size, reduce_all);
@@ -79,14 +76,10 @@ void ReduceGradKernel(const Context& dev_ctx,
   dev_ctx.Alloc(d_x, x.dtype());
   auto pt_d_out = new_d_out;
   auto pt_d_x = *d_x;
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
-
-  phi::ReduceGrad<T, TransformOp<T, MPType>>(
-      dev_ctx,
-      &pt_d_out,
-      &pt_d_x,
-      pt_out_dtype,
-      TransformOp<T, MPType>(reduce_num));
+  std::vector<const DenseTensor*> inputs = {&pt_d_out};
+  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, T>(
+      dev_ctx, inputs, &outputs, 0, functor);
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
index b81a5e50df..50564a339d 100644
--- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
@@ -29,8 +29,23 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::DivideFunctor>(
-      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
+  int dim_size = x.dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (x.dims())[i];
+  }
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  ReduceGradKernel<T, Context, kps::DivideFunctor<T, MPType>>(
+      dev_ctx,
+      x,
+      out_grad,
+      dims,
+      keep_dim,
+      reduce_all,
+      x_grad,
+      kps::DivideFunctor<T, MPType>(reduce_num));
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
index 1ad6b8fefe..8b111641cf 100644
--- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -29,8 +29,40 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::IdentityFunctor>(
-      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  auto out_dtype = x.dtype();
+  auto* in_x = &x;
+  auto* d_out = &out_grad;
+  auto* d_x = x_grad;
+
+  // get reduce_dim and reduce_num for reduce_mean_grad
+  int dim_size = in_x->dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+
+  auto update_dims = vectorize(d_x->dims());
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (in_x->dims())[i];
+    update_dims[i] = 1;
+  }
+  // make new tensor
+  DenseTensor new_d_out(d_out->dtype());
+  new_d_out.ShareDataWith(*d_out);
+  new_d_out.Resize(phi::make_ddim(update_dims));
+
+  dev_ctx.Alloc(d_x, x.dtype());
+  auto pt_out_dtype = x.dtype();
+  auto pt_d_out = new_d_out;
+  auto pt_d_x = *d_x;
+  std::vector<const DenseTensor*> inputs = {&pt_d_out};
+  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
+      dev_ctx,
+      &pt_d_out,
+      &pt_d_x,
+      pt_out_dtype,
+      kps::IdentityFunctor<T, MPType>());
 }
 
 }  // namespace phi
@@ -48,4 +80,3 @@ PD_REGISTER_KERNEL(sum_grad,
                    int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-
diff --git a/paddle/phi/kernels/gpu/where_kernel.cu b/paddle/phi/kernels/gpu/where_kernel.cu
index a0be388065..441be02b99 100644
--- a/paddle/phi/kernels/gpu/where_kernel.cu
+++ b/paddle/phi/kernels/gpu/where_kernel.cu
@@ -40,8 +40,7 @@ void WhereKernel(const Context& ctx,
   ctx.template Alloc<T>(out);
 
   CondFunctor<T> func;
-  funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>(
-      ctx, ins, &outs, -1, func);
+  funcs::ElementwiseKernel<T, CondFunctor<T>, 1>(ctx, ins, &outs, func);
 }
 
 }  // namespace phi
-- 
GitLab
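For reference, the reduce_grad.h hunk above changes ReduceGradKernel to take a
ready-made functor argument instead of a TransformOp template template
parameter, so per-op state such as reduce_num travels inside the functor. A
minimal sketch of a caller under that new contract, assuming the same phi
headers as the patch (MySumLikeGradKernel is a hypothetical name):

    // Hypothetical caller of the new functor-taking ReduceGradKernel; an
    // IdentityFunctor reproduces sum-grad, a DivideFunctor(reduce_num)
    // reproduces mean-grad (see ReduceMeanGradKernel in the diff above).
    template <typename T, typename Context>
    void MySumLikeGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& out_grad,
                             const std::vector<int64_t>& dims,
                             bool keep_dim,
                             bool reduce_all,
                             DenseTensor* x_grad) {
      using MPType = typename kps::details::MPTypeTrait<T>::Type;
      ReduceGradKernel<T, Context, kps::IdentityFunctor<T, MPType>>(
          dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad,
          kps::IdentityFunctor<T, MPType>());
    }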