diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index eb76eee104889042e470e65414a011afd0420d0f..160617695338a9f2e140b7b418c93ef0d7c57e17 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -36,9 +36,9 @@ void TensorReduceImpl(const platform::CUDADeviceContext& dev_ctx,
                       const std::vector<int>& origin_reduce_dims,
                       gpuStream_t stream) {
   y->mutable_data<Ty>(x.place());
-  phi::funcs::TensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
+  phi::funcs::ReduceKernel<Tx, Ty, ReduceOp, TransformOp>(
       static_cast<const phi::GPUContext&>(dev_ctx), x, y, transform,
-      origin_reduce_dims, stream);
+      origin_reduce_dims);
 }
 
 }  // namespace operators
diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cu b/paddle/phi/kernels/funcs/matrix_reduce.cu
index 5e288c6e9c21703471ba7b6a6014510ba845ebd8..5c3ebd6bb01671eab670b477c0d97a962b2eaea0 100644
--- a/paddle/phi/kernels/funcs/matrix_reduce.cu
+++ b/paddle/phi/kernels/funcs/matrix_reduce.cu
@@ -45,13 +45,8 @@ class MatrixReduceSumFunctor<T, GPUContext> {
         out_reduce_dims.push_back(idx);
       }
     }
-    TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        dev_ctx,
-        in,
-        out,
-        kps::IdentityFunctor<T>(),
-        out_reduce_dims,
-        dev_ctx.stream());
+    ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        dev_ctx, in, out, kps::IdentityFunctor<T>(), out_reduce_dims);
   }
 };
 
diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index ce6bb0d559c8143ffa443238043541bec987e0d3..5834f091d9a4de02afe7488ededc0189ae6f21d0 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -1087,12 +1087,12 @@ template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
           typename TransformOp>
-void TensorReduceImpl(const phi::GPUContext& dev_ctx,
-                      const phi::DenseTensor& x,
-                      phi::DenseTensor* y,
-                      const TransformOp& transform,
-                      const std::vector<int>& origin_reduce_dims,
-                      KPStream stream) {
+void ReduceKernel(const phi::GPUContext& dev_ctx,
+                  const phi::DenseTensor& x,
+                  phi::DenseTensor* y,
+                  const TransformOp& transform,
+                  const std::vector<int>& origin_reduce_dims) {
+  auto stream = dev_ctx.stream();
   dev_ctx.Alloc<Ty>(y);
 
   auto x_dim = phi::vectorize<int>(x.dims());
diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
index 926dffc7450dc6db0ff9d2384e92a9ece374026c..d4850b74477d29e868698e000fdc01e708d172b5 100644
--- a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
@@ -87,13 +87,12 @@ void BroadcastTensorsGradKernel(const Context& ctx,
             *input_tensor, ctx.GetPlace(), ctx, output_tensor);
       } else {
         // reduce_sum implementation on CUDA
-        funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
             ctx,
             *input_tensor,
             output_tensor,
             kps::IdentityFunctor<T>(),
-            reduce_dims_vec,
-            ctx.stream());
+            reduce_dims_vec);
       }
     }
   }
diff --git a/paddle/phi/kernels/gpu/compare_kernel.cu b/paddle/phi/kernels/gpu/compare_kernel.cu
index 9c02627e5463b125076d86daef1b52fe4502a7e0..225164687b75ca88f0c0783d6bdabea227c076ae 100644
--- a/paddle/phi/kernels/gpu/compare_kernel.cu
+++ b/paddle/phi/kernels/gpu/compare_kernel.cu
@@ -80,8 +80,8 @@ inline void CompareAllKernelImpl(const Context& ctx,
   for (int i = 0; i < reduce_dims.size(); ++i) {
     reduce_dims[i] = i;
   }
-  funcs::TensorReduceImpl<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
-      ctx, tmp, out, kps::IdentityFunctor<bool>(), reduce_dims, ctx.stream());
+  funcs::ReduceKernel<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
+      ctx, tmp, out, kps::IdentityFunctor<bool>(), reduce_dims);
 }
 
 }  // namespace phi
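Note for reviewers: the migration pattern is identical at every call site. A minimal before/after sketch, with hypothetical names (`src`, `dst`, `dims` are illustrative, not taken from this patch):

    // Before: the caller obtained the CUDA stream and passed it explicitly.
    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
        dev_ctx, src, dst, kps::IdentityFunctor<T>(), dims, dev_ctx.stream());

    // After: ReduceKernel reads the stream from dev_ctx itself, so the
    // trailing stream argument (and any local stream variable) disappears.
    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
        dev_ctx, src, dst, kps::IdentityFunctor<T>(), dims);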
diff --git a/paddle/phi/kernels/gpu/elementwise_grad.h b/paddle/phi/kernels/gpu/elementwise_grad.h
index b356f19555fc426dd5ce184ebe0f5ff39213aa8e..98df65c92f34c1704e7b89badb8bfa22ac16fa86 100644
--- a/paddle/phi/kernels/gpu/elementwise_grad.h
+++ b/paddle/phi/kernels/gpu/elementwise_grad.h
@@ -29,13 +29,8 @@ void ReduceWrapper(const GPUContext &dev_ctx,
                    DenseTensor *dst) {
   std::vector<int> reduce_dims =
       funcs::GetReduceDim(dst->dims(), src->dims(), axis);
-  funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-      dev_ctx,
-      *src,
-      dst,
-      kps::IdentityFunctor<T>(),
-      reduce_dims,
-      dev_ctx.stream());
+  funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      dev_ctx, *src, dst, kps::IdentityFunctor<T>(), reduce_dims);
 }
 
 template <typename T>
@@ -172,9 +167,8 @@ void DefaultElementwiseAddGrad(const GPUContext &ctx,
       }
       std::vector<int> reduce_dims =
           funcs::GetReduceDim(x.dims(), out.dims(), axis);
-      gpuStream_t stream = ctx.stream();
-      funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
+      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims);
     }
   }
   // dy
@@ -187,9 +181,8 @@ void DefaultElementwiseAddGrad(const GPUContext &ctx,
     } else {
       std::vector<int> reduce_dims =
           funcs::GetReduceDim(y.dims(), out.dims(), axis);
-      gpuStream_t stream = ctx.stream();
-      funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          ctx, dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
+      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          ctx, dout, dy, kps::IdentityFunctor<T>(), reduce_dims);
     }
   }
 }
@@ -285,9 +278,8 @@ void default_elementwise_sub_grad(const GPUContext &ctx,
       }
       std::vector<int> reduce_dims =
           funcs::GetReduceDim(x.dims(), out.dims(), axis);
-      gpuStream_t stream = ctx.stream();
-      funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
+      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims);
     }
   }
   // dy
@@ -306,9 +298,8 @@ void default_elementwise_sub_grad(const GPUContext &ctx,
     } else {
       std::vector<int> reduce_dims =
           funcs::GetReduceDim(y.dims(), out.dims(), axis);
-      gpuStream_t stream = ctx.stream();
-      funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
-          ctx, dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
+      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
+          ctx, dout, dy, kps::InverseFunctor<T>(), reduce_dims);
     }
   }
 }
diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h
index 0319de7558e824926e37469a3c222c1f9a9673fc..da5315f34479f92bfb0e5d807e28882eafa3d2ac 100644
--- a/paddle/phi/kernels/gpu/reduce.h
+++ b/paddle/phi/kernels/gpu/reduce.h
@@ -39,8 +39,6 @@ void Reduce(const KPDevice& dev_ctx,
     reduce_num *= (x.dims())[i];
   }
 
-  KPStream stream = dev_ctx.stream();
-
   if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) {
     auto tmp_tensor = phi::Cast<T>(dev_ctx, x, out_dtype);
     PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(
@@ -48,29 +46,23 @@ void Reduce(const KPDevice& dev_ctx,
         phi::DataType::INT64,
         phi::DataType::FLOAT16,
         out_dtype,
-        "TensorReduceImpl",
+        "ReduceKernel",
         ([&] {
           using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
-          phi::funcs::TensorReduceImpl<data_t,
-                                       data_t,
-                                       ReduceOp,
-                                       TransformOp<data_t, MPType>>(
+          phi::funcs::ReduceKernel<data_t,
+                                   data_t,
+                                   ReduceOp,
+                                   TransformOp<data_t, MPType>>(
               dev_ctx,
               tmp_tensor,
               out,
               TransformOp<data_t, MPType>(reduce_num),
-              reduce_dims,
-              stream);
+              reduce_dims);
         }));
   } else {
     using MPType = typename kps::details::MPTypeTrait<T>::Type;
-    phi::funcs::TensorReduceImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
-        dev_ctx,
-        x,
-        out,
-        TransformOp<T, MPType>(reduce_num),
-        reduce_dims,
-        stream);
+    phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
+        dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
   }
 }
 }  // namespace phi
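As the reduce.h hunk shows, the Reduce() wrapper no longer threads a KPStream through. A short usage sketch of the renamed phi-level entry point, assuming a phi::GPUContext `ctx`, an input DenseTensor `x`, an output `out`, and element type `T` are in scope (`std::iota` needs `<numeric>`); illustrative, not code from the patch:

    // Sum-reduce over every axis of x: reduce_dims = {0, 1, ..., rank - 1}.
    std::vector<int> reduce_dims(x.dims().size());
    std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
    phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
        ctx, x, &out, kps::IdentityFunctor<T>(), reduce_dims);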
diff --git a/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_grad_kernel.cu b/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_grad_kernel.cu
index 598b0138fb3a11d63230b9b40071fe7972c48da5..6fc65006ae264e29a5637cb18a95b5785c922f07 100644
--- a/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_grad_kernel.cu
@@ -69,17 +69,12 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
   dev_ctx.template Alloc<T>(counts_tensor);
   counts_tensor->Resize(in_grad->dims());
 
-  int limit = in_grad->numel();
-  int blocks = NumBlocks(limit);
-  int threads = kNumCUDAThreads;
   std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
   std::vector<DenseTensor *> outs = {in_grad, counts_tensor};
   auto functor = SigmoidBwdFunctor<T>(ignore_index);
-  constexpr int Size = 2;
-  phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(
+  phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
       dev_ctx, ins, &outs, functor);
   if (normalize) {
-    T *counts = dev_ctx.template Alloc<T>(counts_tensor);
     DenseTensor *norm_tensor = new DenseTensor();
     norm_tensor->Resize({sizeof(T)});
     dev_ctx.template Alloc<T>(norm_tensor);
@@ -89,13 +84,8 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
       reduce_dim.push_back(i);
     }
 
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
-        dev_ctx,
-        *counts_tensor,
-        norm_tensor,
-        NonzeroFunctor<T>(),
-        reduce_dim,
-        dev_ctx.stream());
+    funcs::ReduceKernel<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+        dev_ctx, *counts_tensor, norm_tensor, NonzeroFunctor<T>(), reduce_dim);
     T *norm = dev_ctx.template Alloc<T>(norm_tensor);
     auto norm_cpu_mem = paddle::memory::Alloc(phi::CPUPlace(), sizeof(T));
     T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
@@ -114,6 +104,7 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
     phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs, div_functor);
     delete norm_tensor;
   }
+  delete counts_tensor;
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_kernel.cu b/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_kernel.cu
index 13d63f8d97e42cf55c4e2acc1d3f7bf92efef868..4b6e5628c72af438fc98cf6f64c5a36adcd6d2ee 100644
--- a/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_kernel.cu
+++ b/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_kernel.cu
@@ -69,17 +69,12 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
   dev_ctx.template Alloc<T>(counts_tensor);
   counts_tensor->Resize(out->dims());
 
-  int limit = out->numel();
-  int blocks = NumBlocks(limit);
-  int threads = kNumCUDAThreads;
   std::vector<const DenseTensor *> ins = {&x, &label};
   std::vector<DenseTensor *> outs = {out, counts_tensor};
   auto functor = SigmoidFwdFunctor<T>(ignore_index);
-  constexpr int Size = 2;
-  phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(
+  phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
       dev_ctx, ins, &outs, functor);
   if (normalize) {
-    T *counts = dev_ctx.template Alloc<T>(counts_tensor);
     DenseTensor *norm_tensor = new DenseTensor();
     norm_tensor->Resize({sizeof(T)});
     dev_ctx.template Alloc<T>(norm_tensor);
@@ -89,13 +84,8 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
       reduce_dim.push_back(i);
     }
 
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
-        dev_ctx,
-        *counts_tensor,
-        norm_tensor,
-        NonzeroFunctor<T>(),
-        reduce_dim,
-        dev_ctx.stream());
+    funcs::ReduceKernel<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+        dev_ctx, *counts_tensor, norm_tensor, NonzeroFunctor<T>(), reduce_dim);
     T *norm = dev_ctx.template Alloc<T>(norm_tensor);
     auto norm_cpu_mem = paddle::memory::Alloc(phi::CPUPlace(), sizeof(T));
     T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
@@ -114,8 +104,8 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
 
     phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs, div_functor);
     delete norm_tensor;
-    delete counts_tensor;
   }
+  delete counts_tensor;
 }
 
 }  // namespace phi
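Besides the rename, the two sigmoid files above fix a leak and drop dead code: `limit`/`blocks`/`threads` and the duplicate `Alloc` of `counts_tensor` were unused, and `counts_tensor` (allocated with `new` on every call) was only freed on the `normalize` path; moving `delete counts_tensor;` past the branch frees it on both paths. A sketch of an alternative that would avoid manual delete entirely (not what this patch does; purely illustrative):

    // Automatic storage: the tensor is destroyed at scope exit on every
    // path, so neither branch can leak it; outs then takes its address.
    DenseTensor counts_tensor;
    counts_tensor.Resize(out->dims());
    dev_ctx.template Alloc<T>(&counts_tensor);
    std::vector<DenseTensor *> outs = {out, &counts_tensor};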
diff --git a/paddle/phi/kernels/gpu/trace_kernel.cu b/paddle/phi/kernels/gpu/trace_kernel.cu
index 4266f0174ff6c17b2146576544ff090c7a272872..4a749c5b3347da24c8aba35d33673801b4b7e407 100644
--- a/paddle/phi/kernels/gpu/trace_kernel.cu
+++ b/paddle/phi/kernels/gpu/trace_kernel.cu
@@ -31,11 +31,10 @@ void TraceKernel(const Context& ctx,
   T* out_data = ctx.template Alloc<T>(out);
   auto diag = funcs::Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
   if (diag.numel() > 0) {
-    auto stream = ctx.stream();
     std::vector<int> reduce_dims;
     reduce_dims.push_back(out->dims().size());
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims, stream);
+    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims);
   } else {
     phi::funcs::SetConstant<Context, T> functor;
     functor(ctx, out, static_cast<T>(0));
diff --git a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
index d06bdc55030567ae4de8ba51bec7282231cc8661..495b93f2a4ef0f790d53605e4531af7040c6b2ad 100644
--- a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
@@ -59,9 +59,8 @@ struct ReduceSumForMatmulGrad<GPUContext, T> {
                   const DenseTensor& input,
                   DenseTensor* output,
                   const std::vector<int>& reduce_dims) {
-    auto stream = dev_ctx.stream();
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        dev_ctx, input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
+    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        dev_ctx, input, output, kps::IdentityFunctor<T>(), reduce_dims);
   }
 };
 #endif
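For completeness, a sketch of how the updated ReduceSumForMatmulGrad functor is typically invoked in matmul's backward pass to fold away a broadcast batch dimension. The shapes, the `d_out` name, and the assumption that `operator()` takes the device context as its first parameter are illustrative, inferred from the hunk above rather than quoted from it:

    // Sum away broadcast batch dim 0: [batch, M, N] -> [M, N].
    DenseTensor reduced;
    reduced.Resize({M, N});
    std::vector<int> reduce_dims = {0};
    ReduceSumForMatmulGrad<GPUContext, T>()(
        dev_ctx, d_out, &reduced, reduce_dims);

Note that the legacy paddle::operators::TensorReduceImpl wrapper in reduce_op.cu.h (first hunk of this patch) keeps its gpuStream_t parameter for source compatibility, even though the forwarded-to ReduceKernel now ignores it and takes the stream from the device context.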