From 7a4a512daa172062068c7fab669bd321f1926274 Mon Sep 17 00:00:00 2001
From: chentianyu03
Date: Wed, 5 Jan 2022 14:33:46 +0800
Subject: [PATCH] [pten]Move reduce code new (#38648)

* change 'math' to 'math_kernel'

* fix compile bugs

* merge develop

* fix compile bugs

* fix compile bugs

* move reduce files by new rule

* add set header

* format code style

* merge develop and fix conflict

* merge develop and fix conflict

Co-authored-by: YuanRisheng
---
 .../reduce_ops/check_reduce_rank_test.cu      |  2 +-
 .../fluid/operators/reduce_ops/reduce_op.cu.h |  2 +-
 paddle/fluid/operators/reduce_ops/reduce_op.h |  6 +-
 paddle/pten/include/math.h                    |  4 +-
 paddle/pten/kernels/cpu/math_kernel.cc        | 41 +++++------
 .../general/reduce_impl.h => cpu/reduce.h}    |  9 ++-
 paddle/pten/kernels/gpu/math_kernel.cu        | 33 +++++----
 .../reduce_cuda_impl.h => gpu/reduce.h}       | 48 +++++++++++++
 .../pten/kernels/hybird/cuda/reduce/reduce.h  | 71 -------------------
 paddle/pten/kernels/math_kernel.h             | 26 +++----
 10 files changed, 107 insertions(+), 135 deletions(-)
 rename paddle/pten/kernels/{hybird/general/reduce_impl.h => cpu/reduce.h} (95%)
 rename paddle/pten/kernels/{hybird/cuda/reduce/reduce_cuda_impl.h => gpu/reduce.h} (96%)
 delete mode 100644 paddle/pten/kernels/hybird/cuda/reduce/reduce.h

diff --git a/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu b/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
index 63d42790205..33e195f8992 100644
--- a/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
+++ b/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
+#include "paddle/pten/kernels/gpu/reduce.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index e779da641b9..62486f62f66 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -32,7 +32,7 @@ namespace cub = hipcub;
 
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
+#include "paddle/pten/kernels/gpu/reduce.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index bd09a7951aa..e1854d8a13d 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -28,10 +28,10 @@ limitations under the License. */
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
 #include "paddle/pten/include/core.h"
 #include "paddle/pten/include/math.h"
-#include "paddle/pten/kernels/hybird/general/reduce_impl.h"
+#include "paddle/pten/kernels/cpu/reduce.h"
 
 #if defined(__HIPCC__) || defined(__NVCC__)
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
+#include "paddle/pten/kernels/gpu/reduce.h"
 #endif
 
 namespace paddle {
@@ -259,7 +259,7 @@ class ReduceKernel : public framework::OpKernel<T> {
       std::vector<int64_t> tmp_dims(dims.begin(), dims.end());
 
       // call new kernel
-      pten::general::Reduce<DeviceContext, T, Functor>(
+      pten::Reduce<DeviceContext, T, Functor>(
          dev_ctx, *pt_x.get(), reduce_all, tmp_dims, keep_dim,
          pten::TransToPtenDataType(cast_out_dtype), pt_out.get());
     }
diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h
index 9abfa297a94..e46f460260a 100644
--- a/paddle/pten/include/math.h
+++ b/paddle/pten/include/math.h
@@ -45,7 +45,7 @@ DenseTensor Mean(const ContextT& dev_ctx,
           dev_ctx.GetPlace()),
       std::move(out_meta));
   bool reduce_all = false;
-  Mean<T>(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out);
+  MeanKernel<T>(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out);
   return dense_out;
 }
 
@@ -65,7 +65,7 @@ DenseTensor Sum(const ContextT& dev_ctx,
   // so use default value(false) is OK.
   bool reduce_all = false;
 
-  Sum<T>(
+  SumKernel<T>(
       dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out);
   return dense_out;
 }
diff --git a/paddle/pten/kernels/cpu/math_kernel.cc b/paddle/pten/kernels/cpu/math_kernel.cc
index c022dd08bbe..4f895d9514a 100644
--- a/paddle/pten/kernels/cpu/math_kernel.cc
+++ b/paddle/pten/kernels/cpu/math_kernel.cc
@@ -18,13 +18,10 @@
 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/pten/core/kernel_registry.h"
-
 #include "paddle/pten/kernels/cpu/elementwise_impl.h"
+#include "paddle/pten/kernels/cpu/reduce.h"
 #include "paddle/pten/kernels/funcs/elementwise_functor.h"
-#include "paddle/pten/kernels/hybird/eigen/reduce.h"
-#include "paddle/pten/kernels/hybird/general/reduce_impl.h"
-
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
@@ -57,14 +54,14 @@ namespace pten {
 }
 
 template <typename T, typename Context>
-void Mean(const Context& dev_ctx,
-          const DenseTensor& x,
-          const std::vector<int64_t>& dims,
-          bool keep_dim,
-          bool reduce_all,
-          DenseTensor* out) {
+void MeanKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int64_t>& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out) {
   auto out_dtype = x.dtype();
-  pten::general::Reduce<CPUContext, T, pten::eigen::MeanFunctor>(
+  pten::Reduce<CPUContext, T, pten::eigen::MeanFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
 
@@ -93,14 +90,14 @@ void DivideKernel(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void Sum(const Context& dev_ctx,
-         const DenseTensor& x,
-         const std::vector<int64_t>& dims,
-         bool keep_dim,
-         bool reduce_all,
-         DataType out_dtype,
-         DenseTensor* out) {
-  pten::general::Reduce<CPUContext, T, pten::eigen::SumFunctor>(
+void SumKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               bool reduce_all,
+               DataType out_dtype,
+               DenseTensor* out) {
+  pten::Reduce<CPUContext, T, pten::eigen::SumFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
 
@@ -120,8 +117,8 @@ using complex128 = ::paddle::platform::complex<double>;
 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
 // using bfloat16 = ::paddle::platform::bfloat16;
 
-PT_REGISTER_CTX_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {
-}
+PT_REGISTER_CTX_KERNEL(
+    mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
 
 PT_REGISTER_CTX_KERNEL(add,
                        CPU,
                        ALL_LAYOUT,
@@ -166,7 +163,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
 PT_REGISTER_CTX_KERNEL(sum,
                        CPU,
                        ALL_LAYOUT,
-                       pten::Sum,
+                       pten::SumKernel,
                        bool,
                        float,
                        double,
diff --git a/paddle/pten/kernels/hybird/general/reduce_impl.h b/paddle/pten/kernels/cpu/reduce.h
similarity index 95%
rename from paddle/pten/kernels/hybird/general/reduce_impl.h
rename to paddle/pten/kernels/cpu/reduce.h
index 631ad7f6125..fc5dbe9d58d 100644
--- a/paddle/pten/kernels/hybird/general/reduce_impl.h
+++ b/paddle/pten/kernels/cpu/reduce.h
@@ -13,14 +13,15 @@
 // limitations under the License.
 
 #pragma once
-#include "paddle/fluid/platform/transform.h"
+
+#include <set>
+
 #include "paddle/pten/api/ext/dispatch.h"
-#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/kernels/cast_kernel.h"
 #include "paddle/pten/kernels/hybird/eigen/reduce.h"
 
 namespace pten {
-namespace general {
 
 template <typename DeviceContext, typename T, typename Functor>
 void Reduce(const DeviceContext& dev_ctx,
@@ -71,6 +72,4 @@ void Reduce(const DeviceContext& dev_ctx,
   }
 }
 
-}  // namespace general
-
 }  // namespace pten
diff --git a/paddle/pten/kernels/gpu/math_kernel.cu b/paddle/pten/kernels/gpu/math_kernel.cu
index 760bebe6878..051f7cb3bdd 100644
--- a/paddle/pten/kernels/gpu/math_kernel.cu
+++ b/paddle/pten/kernels/gpu/math_kernel.cu
@@ -16,9 +16,8 @@ limitations under the License. */
 
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/kernels/funcs/elementwise_functor.h"
+#include "paddle/pten/kernels/gpu/reduce.h"
 #include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h"
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
-#include "paddle/pten/kernels/hybird/general/reduce_impl.h"
 
 #ifdef __NVCC__
 #include "cub/cub.cuh"
@@ -76,12 +75,12 @@ struct DivideFunctor {
  */
 
 template <typename T, typename Context>
-void Mean(const Context& dev_ctx,
-          const DenseTensor& x,
-          const std::vector<int64_t>& dims,
-          bool keep_dim,
-          bool reduce_all,
-          DenseTensor* out) {
+void MeanKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int64_t>& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out) {
   auto out_dtype = x.dtype();
   pten::Reduce<T, kps::AddFunctor, kps::DivideFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
@@ -97,13 +96,13 @@ DEFINE_CUDA_ELEMENTWISE_OP(Multiply)
 DEFINE_CUDA_ELEMENTWISE_OP(Divide)
 
 template <typename T, typename Context>
-void Sum(const Context& dev_ctx,
-         const DenseTensor& x,
-         const std::vector<int64_t>& dims,
-         bool keep_dim,
-         bool reduce_all,
-         DataType out_dtype,
-         DenseTensor* out) {
+void SumKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               bool reduce_all,
+               DataType out_dtype,
+               DenseTensor* out) {
   pten::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -115,7 +114,7 @@ using complex64 = ::paddle::platform::complex<float>;
 using complex128 = ::paddle::platform::complex<double>;
 
 PT_REGISTER_CTX_KERNEL(
-    mean, GPU, ALL_LAYOUT, pten::Mean, float, double, bool, float16) {}
+    mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
 
 PT_REGISTER_CTX_KERNEL(add,
                        GPU,
                        ALL_LAYOUT,
@@ -164,7 +163,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
 PT_REGISTER_CTX_KERNEL(sum,
                        GPU,
                        ALL_LAYOUT,
-                       pten::Sum,
+                       pten::SumKernel,
                        bool,
                        float,
                        double,
diff --git a/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h b/paddle/pten/kernels/gpu/reduce.h
similarity index 96%
rename from paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h
rename to paddle/pten/kernels/gpu/reduce.h
index 4cfcad9149a..0704b76a2f0 100644
--- a/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h
+++ b/paddle/pten/kernels/gpu/reduce.h
@@ -14,6 +14,9 @@
 
 #pragma once
 
+// CUDA and HIP use same api
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+
 #include <algorithm>
 #include <cmath>
 #include <numeric>
@@ -40,6 +43,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/string/string_helper.h"
 
 #include "paddle/pten/api/ext/dispatch.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/kernels/cast_kernel.h"
 #include "paddle/pten/kernels/copy_kernel.h"
@@ -1230,4 +1234,48 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
 }
 
 }  // namespace kernels
+
+template <typename T,
+          template <typename> class ReduceOp,
+          template <typename, typename> class TransformOp>
+void Reduce(const GPUContext& dev_ctx,
+            const DenseTensor& x,
+            bool reduce_all,
+            const std::vector<int64_t>& dims,
+            bool keep_dim,
+            DataType out_dtype,
+            DenseTensor* out) {
+  std::vector<int> reduce_dims =
+      pten::kernels::details::GetReduceDim(dims, x.dims().size(), reduce_all);
+
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (x.dims())[i];
+  }
+
+  gpuStream_t stream = dev_ctx.stream();
+
+  if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
+    PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(
+        pten::DataType::INT32,
+        pten::DataType::INT64,
+        out_dtype,
+        "TensorReduceFunctorImpl",
+        ([&] {
+          using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
+          pten::kernels::TensorReduceFunctorImpl<T,
+                                                 data_t,
+                                                 ReduceOp,
+                                                 TransformOp<T, MPType>>(
+              x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
+        }));
+  } else {
+    using MPType = typename kps::details::MPTypeTrait<T>::Type;
+    pten::kernels::
+        TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
+            x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
+  }
+}
 }  // namespace pten
+
+#endif
diff --git a/paddle/pten/kernels/hybird/cuda/reduce/reduce.h b/paddle/pten/kernels/hybird/cuda/reduce/reduce.h
deleted file mode 100644
index 2281cd5ef78..00000000000
--- a/paddle/pten/kernels/hybird/cuda/reduce/reduce.h
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-// CUDA and HIP use same api
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-
-#include "paddle/pten/api/ext/dispatch.h"
-#include "paddle/pten/backends/gpu/gpu_context.h"
-#include "paddle/pten/common/scalar.h"
-#include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
-namespace pten {
-
-template <typename T,
-          template <typename> class ReduceOp,
-          template <typename, typename> class TransformOp>
-void Reduce(const GPUContext& dev_ctx,
-            const DenseTensor& x,
-            bool reduce_all,
-            const std::vector<int64_t>& dims,
-            bool keep_dim,
-            DataType out_dtype,
-            DenseTensor* out) {
-  std::vector<int> reduce_dims =
-      pten::kernels::details::GetReduceDim(dims, x.dims().size(), reduce_all);
-
-  int reduce_num = 1;
-  for (auto i : reduce_dims) {
-    reduce_num *= (x.dims())[i];
-  }
-
-  gpuStream_t stream = dev_ctx.stream();
-
-  if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
-    PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(
-        pten::DataType::INT32,
-        pten::DataType::INT64,
-        out_dtype,
-        "TensorReduceFunctorImpl",
-        ([&] {
-          using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
-          pten::kernels::TensorReduceFunctorImpl<T,
-                                                 data_t,
-                                                 ReduceOp,
-                                                 TransformOp<T, MPType>>(
-              x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
-        }));
-  } else {
-    using MPType = typename kps::details::MPTypeTrait<T>::Type;
-    pten::kernels::
-        TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
-            x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
-  }
-}
-
-}  // namespace pten
-
-#endif
diff --git a/paddle/pten/kernels/math_kernel.h b/paddle/pten/kernels/math_kernel.h
index 2968aa3524a..b1e5188f3aa 100644
--- a/paddle/pten/kernels/math_kernel.h
+++ b/paddle/pten/kernels/math_kernel.h
@@ -21,12 +21,12 @@ limitations under the License. */
 namespace pten {
 
 template <typename T, typename Context>
-void Mean(const Context& dev_ctx,
-          const DenseTensor& x,
-          const std::vector<int64_t>& dims,
-          bool keep_dim,
-          bool reduce_all,
-          DenseTensor* out);
+void MeanKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int64_t>& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out);
 
 template <typename T, typename Context>
 void AddKernel(const Context& dev_ctx,
@@ -57,13 +57,13 @@ void MultiplyKernel(const Context& dev_ctx,
                     DenseTensor* out);
 
 template <typename T, typename Context>
-void Sum(const Context& dev_ctx,
-         const DenseTensor& x,
-         const std::vector<int64_t>& dims,
-         bool keep_dim,
-         bool reduce_all,
-         DataType out_dtype,
-         DenseTensor* out);
+void SumKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               bool reduce_all,
+               DataType out_dtype,
+               DenseTensor* out);
 
 template <typename T, typename ContextT>
 DenseTensor Add(const ContextT& dev_ctx,
-- 
GitLab
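
Usage sketch (added for reference, not part of the original patch): after this
move, the CPU reduce entry point lives in paddle/pten/kernels/cpu/reduce.h and
the GPU one in paddle/pten/kernels/gpu/reduce.h. A minimal CPU-side call,
assuming an initialized pten::CPUContext `dev_ctx` and allocated DenseTensors
`x` and `out` (hypothetical names); the template arguments mirror SumKernel in
paddle/pten/kernels/cpu/math_kernel.cc above:

    #include "paddle/pten/kernels/cpu/reduce.h"           // pten::Reduce (CPU)
    #include "paddle/pten/kernels/hybird/eigen/reduce.h"  // pten::eigen::SumFunctor

    // Sum `x` over dimension 0. Passing DataType::UNDEFINED keeps the input
    // dtype; any other value casts before reducing (behavior of the renamed
    // reduce_impl.h body, which this patch does not show).
    pten::Reduce<pten::CPUContext, float, pten::eigen::SumFunctor>(
        dev_ctx, x, /*reduce_all=*/false, /*dims=*/{0}, /*keep_dim=*/false,
        pten::DataType::UNDEFINED, &out);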