From a0f586bc626b3fddcc104e46e521e37bc7e4e302 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 21 Jan 2022 20:03:11 +0800 Subject: [PATCH] [PTen]Separate origin Kernel and add Kernel for C++ API (#39002) * add kernel for c++ api * fix compile bugs * fix kunlun compile bugs * perfect cmake * fix compile bugs when run ci-inference * fix compile bugs * add non-raw kernel for fluid op * fix compile bugs * fix compile bugs * fix unit test bug --- cmake/pten_kernel.cmake | 61 +++-- paddle/fluid/operators/cholesky_solve_op.h | 2 +- .../elementwise/elementwise_add_op.h | 2 +- .../elementwise/elementwise_div_op.h | 2 +- .../elementwise/elementwise_mul_op.cu | 4 +- .../elementwise/elementwise_mul_op.h | 2 +- .../operators/elementwise/elementwise_op.h | 24 +- .../elementwise/elementwise_sub_op.h | 2 +- paddle/fluid/operators/lu_op.h | 4 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 13 +- paddle/pten/api/include/kernel_signature.h | 6 - paddle/pten/core/kernel_alias_name.h | 12 +- paddle/pten/kernels/cpu/math_kernel.cc | 76 +++---- paddle/pten/kernels/gpu/math_kernel.cu | 77 ++++--- paddle/pten/kernels/math_kernel.cc | 212 ++++++++++++++++++ paddle/pten/kernels/math_kernel.h | 125 ++++++----- .../tests/kernels/test_elementwise_dev_api.cc | 12 +- python/paddle/utils/code_gen/api.yaml | 7 +- 18 files changed, 453 insertions(+), 190 deletions(-) create mode 100644 paddle/pten/kernels/math_kernel.cc diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index bc9fefb58f..c2928376a0 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -103,38 +103,55 @@ function(kernel_library TARGET) list(LENGTH gpu_srcs gpu_srcs_len) list(LENGTH xpu_srcs xpu_srcs_len) - if (${common_srcs_len} GREATER 0) - # If the kernel has a device independent public implementation, - # we will use this implementation and will not adopt the implementation - # under specific devices + # Build Target according different src organization + if((${cpu_srcs_len} GREATER 0 OR 
${gpu_srcs_len} GREATER 0 OR + ${xpu_srcs_len} GREATER 0) AND ${common_srcs_len} GREATER 0) + # If the common_srcs depends on specific device srcs, build target using this rule. + if (WITH_GPU) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) + endif() + elseif (WITH_ROCM) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + hip_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) + endif() + else() + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + cc_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) + endif() + endif() + elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) if (WITH_GPU) - nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() elseif (WITH_ROCM) - hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() else() - cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() endif() else() - # If the kernel has a header file declaration, but no corresponding - # implementation can be found, this is not allowed - if 
(${cpu_srcs_len} EQUAL 0 AND ${gpu_srcs_len} EQUAL 0 AND - ${xpu_srcs_len} EQUAL 0) - message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") + if (${common_srcs_len} EQUAL 0) + message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") else() + # If the kernel has a device independent public implementation, + # we will use this implementation and will not adopt the implementation + # under specific devices if (WITH_GPU) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) elseif (WITH_ROCM) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() + hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) else() - if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) - cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() + cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() - endif() + endif() endif() if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR diff --git a/paddle/fluid/operators/cholesky_solve_op.h b/paddle/fluid/operators/cholesky_solve_op.h index 4b1d075de9..5004aad7c5 100644 --- a/paddle/fluid/operators/cholesky_solve_op.h +++ b/paddle/fluid/operators/cholesky_solve_op.h @@ -202,7 +202,7 @@ class CholeskySolveGradKernel : public framework::OpKernel { commonterm_for_range(commonterm_functor); commonterm_conj = helper.Transpose(commonterm_conj); - pten::AddKernel( + pten::AddRawKernel( static_cast::TYPE &>(dev_ctx), commonterm, commonterm_conj, -1, &commonterm); diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h 
index a4897a06d5..5c4f791b22 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -61,7 +61,7 @@ class ElementwiseAddKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::AddKernel( + pten::AddRawKernel( static_cast::TYPE &>(dev_ctx), *pt_x.get(), *pt_y.get(), axis, pt_z.get()); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 44f695278d..a45f09b63e 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -51,7 +51,7 @@ class ElementwiseDivKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::DivideKernel( + pten::DivideRawKernel( static_cast::TYPE&>(dev_ctx), *pt_x.get(), *pt_y.get(), axis, pt_z.get()); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 86a8031063..0c7d12ae0a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -51,8 +51,8 @@ class ElementwiseMulKernel auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y_lod); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod); - pten::MultiplyKernel(cuda_ctx, *pt_x.get(), *pt_y.get(), axis, - pt_z.get()); + pten::MultiplyRawKernel(cuda_ctx, *pt_x.get(), *pt_y.get(), axis, + pt_z.get()); } else { PADDLE_THROW(platform::errors::InvalidArgument( "X's type[%s] is not supported by elementwise_op. 
X's type should be " diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index d918407930..e7a5e48b1f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -124,7 +124,7 @@ class ElementwiseMulKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod); - pten::MultiplyKernel( + pten::MultiplyRawKernel( static_cast::TYPE&>(dev_ctx), *pt_x.get(), *pt_y.get(), axis, pt_z.get()); diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index e1d9655e29..aaf33ca674 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -140,26 +140,42 @@ class ElementwiseOp : public framework::OperatorWithKernel { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext &ctx) const override { + int axis = ctx.Attr("axis"); if (Type() == "elementwise_add") { if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature("add", {"X", "Y"}, {"axis"}, {"Out"}); + if (axis == -1) { + return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"}); + } + return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"}, + {"Out"}); } } if (Type() == "elementwise_sub") { if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature("subtract", {"X", "Y"}, {"axis"}, + if (axis == -1) { + return framework::KernelSignature("subtract", {"X", "Y"}, {}, + {"Out"}); + } + return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"}); } } if (Type() == "elementwise_div") { if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature("divide", {"X", "Y"}, {"axis"}, + if (axis == -1) 
{ + return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"}); + } + return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"}); } } if (Type() == "elementwise_mul") { if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature("multiply", {"X", "Y"}, {"axis"}, + if (axis == -1) { + return framework::KernelSignature("multiply", {"X", "Y"}, {}, + {"Out"}); + } + return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"}); } } diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 46d4a93e80..7d1749f20a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -51,7 +51,7 @@ class ElementwiseSubKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::SubtractKernel( + pten::SubtractRawKernel( static_cast::TYPE&>(dev_ctx), *pt_x.get(), *pt_y.get(), axis, pt_z.get()); diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h index 6beef1add8..c3b3552ba1 100644 --- a/paddle/fluid/operators/lu_op.h +++ b/paddle/fluid/operators/lu_op.h @@ -221,7 +221,7 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1, out->Resize(src1.dims()); out->mutable_data(dev_ctx.GetPlace()); - pten::AddKernel< + pten::AddRawKernel< T, typename paddle::framework::ConvertToPtenContext::TYPE>( static_cast::TYPE&>(dev_ctx), @@ -234,7 +234,7 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1, out->Resize(src1.dims()); out->mutable_data(dev_ctx.GetPlace()); - pten::SubtractKernel< + pten::SubtractRawKernel< T, typename paddle::framework::ConvertToPtenContext::TYPE>( static_cast::TYPE&>(dev_ctx), diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h 
b/paddle/fluid/operators/reduce_ops/reduce_op.h index e2002856a4..2e5bd7a42b 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -551,17 +551,26 @@ class ReduceOp : public framework::OperatorWithKernel { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext& ctx) const override { + bool reduce_all = ctx.Attr("reduce_all"); if (Type() == "reduce_sum") { if (ctx.InputVar("X")->IsType()) { + if (!reduce_all) { + return framework::KernelSignature( + "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"}); + } return framework::KernelSignature( - "sum", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"}, + "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"}, {"Out"}); } } if (Type() == "reduce_mean") { if (ctx.InputVar("X")->IsType()) { + if (!reduce_all) { + return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, + {"Out"}); + } return framework::KernelSignature( - "mean", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); + "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); } } // TODO(chentianyu03): support other cases after selected rows added diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h index e3929d5915..d750b47ef8 100644 --- a/paddle/pten/api/include/kernel_signature.h +++ b/paddle/pten/api/include/kernel_signature.h @@ -30,7 +30,6 @@ using DeviceContext = paddle::platform::DeviceContext; using add_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, - int, DenseTensor*); using cast_kernel = void (*)(const DeviceContext&, @@ -46,7 +45,6 @@ using concat_kernel = void (*)(const DeviceContext&, using divide_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, - int, DenseTensor*); using dot_kernel = void (*)(const DeviceContext&, @@ -82,13 +80,11 @@ using mean_kernel = void (*)(const DeviceContext&, const DenseTensor&, 
const std::vector&, bool, - bool, DenseTensor*); using multiply_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, - int, DenseTensor*); using reshape_kernel = void (*)(const DeviceContext&, @@ -107,14 +103,12 @@ using sum_kernel = void (*)(const DeviceContext&, const DenseTensor&, const std::vector&, bool, - bool, DataType, DenseTensor*); using subtract_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, - int, DenseTensor*); using conj_kernel = void (*)(const DeviceContext&, diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 5c86787966..8e089970f9 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -20,10 +20,10 @@ namespace pten { // the key is kernel_name in fluid, the value is the kernel_name in pten // the key is sorted by key's alphabet const std::unordered_map kernel_alias_name_map = { - {"elementwise_add", "add"}, - {"elementwise_div", "divide"}, - {"elementwise_mul", "muliply"}, - {"elementwise_sub", "subtract"}, + {"elementwise_add", "add_raw"}, + {"elementwise_div", "divide_raw"}, + {"elementwise_mul", "multiply_raw"}, + {"elementwise_sub", "subtract_raw"}, {"fill_any_like", "full_like"}, {"fill_constant", "full"}, {"flatten_contiguous_range", "flatten"}, @@ -32,8 +32,8 @@ const std::unordered_map kernel_alias_name_map = { {"matmul_v2_grad", "matmul_grad"}, {"matmul_v2_grad_grad", "matmul_double_grad"}, {"matmul_v2_triple_grad", "matmul_triple_grad"}, - {"reduce_mean", "mean"}, - {"reduce_sum", "sum"}, + {"reduce_mean", "mean_raw"}, + {"reduce_sum", "sum_raw"}, {"reshape2", "reshape"}, {"reshape2_grad", "reshape_grad"}, {"reshape2_grad_grad", "reshape_double_grad"}, diff --git a/paddle/pten/kernels/cpu/math_kernel.cc b/paddle/pten/kernels/cpu/math_kernel.cc index 7841dd4113..706a40936a 100644 --- a/paddle/pten/kernels/cpu/math_kernel.cc +++ b/paddle/pten/kernels/cpu/math_kernel.cc @@ -32,11 +32,11 @@ namespace 
pten { #define DEFINE_CPU_ELEMENTWISE_OP(name) \ template \ - void name##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - int axis, \ - DenseTensor* out) { \ + void name##RawKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + int axis, \ + DenseTensor* out) { \ out->mutable_data(); \ if (x.dims() == y.dims()) { \ SameDimsElementwiseCompute>()( \ @@ -55,23 +55,35 @@ namespace pten { } template -void MeanKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DenseTensor* out) { +void MeanRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { auto out_dtype = x.dtype(); pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } template -void DivideKernel(const Context& dev_ctx, +void SumRawKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& y, - int axis, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DataType out_dtype, DenseTensor* out) { + pten::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + +template +void DivideRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { // allocate memory for out out->mutable_data(); if (x.dims() == y.dims() && std::is_floating_point::value) { @@ -90,18 +102,6 @@ void DivideKernel(const Context& dev_ctx, } } -template -void SumKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DataType out_dtype, - DenseTensor* out) { - pten::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); -} - // Create the definition of Add DEFINE_CPU_ELEMENTWISE_OP(Add) @@ -118,42 +118,40 @@ using complex128 = ::paddle::platform::complex; // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // 
using bfloat16 = ::paddle::platform::bfloat16; -PT_REGISTER_KERNEL( - mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {} -PT_REGISTER_KERNEL(add, +PT_REGISTER_KERNEL(add_raw, CPU, ALL_LAYOUT, - pten::AddKernel, + pten::AddRawKernel, float, double, int, int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL(subtract, +PT_REGISTER_KERNEL(subtract_raw, CPU, ALL_LAYOUT, - pten::SubtractKernel, + pten::SubtractRawKernel, float, double, int, int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL(divide, +PT_REGISTER_KERNEL(divide_raw, CPU, ALL_LAYOUT, - pten::DivideKernel, + pten::DivideRawKernel, float, double, int, int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL(multiply, +PT_REGISTER_KERNEL(multiply_raw, CPU, ALL_LAYOUT, - pten::MultiplyKernel, + pten::MultiplyRawKernel, float, double, int, @@ -161,10 +159,10 @@ PT_REGISTER_KERNEL(multiply, bool, complex64, complex128) {} -PT_REGISTER_KERNEL(sum, +PT_REGISTER_KERNEL(sum_raw, CPU, ALL_LAYOUT, - pten::SumKernel, + pten::SumRawKernel, bool, float, double, @@ -175,3 +173,5 @@ PT_REGISTER_KERNEL(sum, complex128) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } +PT_REGISTER_KERNEL( + mean_raw, CPU, ALL_LAYOUT, pten::MeanRawKernel, float, double, bool) {} diff --git a/paddle/pten/kernels/gpu/math_kernel.cu b/paddle/pten/kernels/gpu/math_kernel.cu index d7a16ac49b..6b6383f810 100644 --- a/paddle/pten/kernels/gpu/math_kernel.cu +++ b/paddle/pten/kernels/gpu/math_kernel.cu @@ -37,11 +37,11 @@ namespace pten { #define DEFINE_CUDA_ELEMENTWISE_OP(name) \ template \ - void name##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - int axis, \ - DenseTensor* out) { \ + void name##RawKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + int axis, \ + DenseTensor* out) { \ std::vector inputs; \ std::vector outputs; \ inputs.emplace_back(&x); \ @@ -57,17 +57,29 @@ namespace pten { */ template -void MeanKernel(const 
Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DenseTensor* out) { +void MeanRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { auto out_dtype = x.dtype(); pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } +template +void SumRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DataType out_dtype, + DenseTensor* out) { + pten::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + // Create the definition of Add DEFINE_CUDA_ELEMENTWISE_OP(Add) // Create the definition of Subtract @@ -77,30 +89,16 @@ DEFINE_CUDA_ELEMENTWISE_OP(Multiply) // Create the definition of Divide DEFINE_CUDA_ELEMENTWISE_OP(Divide) -template -void SumKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DataType out_dtype, - DenseTensor* out) { - pten::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); -} - } // namespace pten using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_KERNEL( - mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {} -PT_REGISTER_KERNEL(add, +PT_REGISTER_KERNEL(add_raw, GPU, ALL_LAYOUT, - pten::AddKernel, + pten::AddRawKernel, float, double, int, @@ -108,10 +106,10 @@ PT_REGISTER_KERNEL(add, float16, complex64, complex128) {} -PT_REGISTER_KERNEL(subtract, +PT_REGISTER_KERNEL(subtract_raw, GPU, ALL_LAYOUT, - pten::SubtractKernel, + pten::SubtractRawKernel, float, double, int, @@ -119,10 +117,10 @@ PT_REGISTER_KERNEL(subtract, float16, complex64, complex128) {} -PT_REGISTER_KERNEL(divide, +PT_REGISTER_KERNEL(divide_raw, GPU, ALL_LAYOUT, - pten::DivideKernel, + pten::DivideRawKernel, float, double, int, @@ -130,10 
+128,10 @@ PT_REGISTER_KERNEL(divide, float16, complex64, complex128) {} -PT_REGISTER_KERNEL(multiply, +PT_REGISTER_KERNEL(multiply_raw, GPU, ALL_LAYOUT, - pten::MultiplyKernel, + pten::MultiplyRawKernel, float, double, int, @@ -142,10 +140,10 @@ PT_REGISTER_KERNEL(multiply, float16, complex64, complex128) {} -PT_REGISTER_KERNEL(sum, +PT_REGISTER_KERNEL(sum_raw, GPU, ALL_LAYOUT, - pten::SumKernel, + pten::SumRawKernel, bool, float, double, @@ -156,3 +154,12 @@ PT_REGISTER_KERNEL(sum, complex128) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } + +PT_REGISTER_KERNEL(mean_raw, + GPU, + ALL_LAYOUT, + pten::MeanRawKernel, + float, + double, + bool, + float16) {} diff --git a/paddle/pten/kernels/math_kernel.cc b/paddle/pten/kernels/math_kernel.cc new file mode 100644 index 0000000000..423282ab97 --- /dev/null +++ b/paddle/pten/kernels/math_kernel.cc @@ -0,0 +1,212 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/pten/kernels/math_kernel.h" + +#include "paddle/pten/backends/all_context.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { + +template +void MeanKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DenseTensor* out) { + bool reduce_all = false; + MeanRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); +} + +template +void SumKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DataType out_dtype, + DenseTensor* out) { + bool reduce_all = false; + SumRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out); +} + +template +void AddKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + AddRawKernel(dev_ctx, x, y, axis, out); +} + +template +void SubtractKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + SubtractRawKernel(dev_ctx, x, y, axis, out); +} + +template +void DivideKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + DivideRawKernel(dev_ctx, x, y, axis, out); +} + +template +void MultiplyKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + MultiplyRawKernel(dev_ctx, x, y, axis, out); +} + +} // namespace pten + +using complex64 = ::paddle::platform::complex; +using complex128 = ::paddle::platform::complex; + +PT_REGISTER_KERNEL( + mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {} + +PT_REGISTER_KERNEL(sum, + CPU, + ALL_LAYOUT, + pten::SumKernel, + bool, + float, + double, + paddle::platform::float16, + int, + int64_t, + complex64, + complex128) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} + +PT_REGISTER_KERNEL(add, + CPU, + ALL_LAYOUT, + pten::AddKernel, + float, + double, + int, + int64_t, + 
complex64, + complex128) {} +PT_REGISTER_KERNEL(subtract, + CPU, + ALL_LAYOUT, + pten::SubtractKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} +PT_REGISTER_KERNEL(divide, + CPU, + ALL_LAYOUT, + pten::DivideKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} +PT_REGISTER_KERNEL(multiply, + CPU, + ALL_LAYOUT, + pten::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + complex64, + complex128) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_REGISTER_KERNEL(mean, + GPU, + ALL_LAYOUT, + pten::MeanKernel, + float, + double, + bool, + paddle::platform::float16) {} +PT_REGISTER_KERNEL(sum, + GPU, + ALL_LAYOUT, + pten::SumKernel, + bool, + float, + double, + paddle::platform::float16, + int, + int64_t, + complex64, + complex128) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} +PT_REGISTER_KERNEL(add, + GPU, + ALL_LAYOUT, + pten::AddKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(subtract, + GPU, + ALL_LAYOUT, + pten::SubtractKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(divide, + GPU, + ALL_LAYOUT, + pten::DivideKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(multiply, + GPU, + ALL_LAYOUT, + pten::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + paddle::platform::float16, + complex64, + complex128) {} +#endif diff --git a/paddle/pten/kernels/math_kernel.h b/paddle/pten/kernels/math_kernel.h index 65c0f84e69..95379baaf3 100644 --- a/paddle/pten/kernels/math_kernel.h +++ b/paddle/pten/kernels/math_kernel.h @@ -22,104 +22,127 @@ limitations under the License. 
*/ namespace pten { +template +void MeanRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out); + template void MeanKernel(const Context& dev_ctx, const DenseTensor& x, const std::vector& dims, bool keep_dim, - bool reduce_all, DenseTensor* out); +template +void SumRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DataType out_dtype, + DenseTensor* out); + +template +void SumKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DataType out_dtype, + DenseTensor* out); + +template +void AddRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + template void AddKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - int axis, DenseTensor* out); +template +void SubtractRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + template void SubtractKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - int axis, DenseTensor* out); +template +void DivideRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + template void DivideKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - int axis, DenseTensor* out); +template +void MultiplyRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + template void MultiplyKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - int axis, DenseTensor* out); -template -void SumKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DataType out_dtype, - DenseTensor* out); - template DenseTensor Add(const Context& dev_ctx, const DenseTensor& x, - 
const DenseTensor& y, - int axis) { - auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - AddKernel(dev_ctx, x, y, axis, &dense_out); + const DenseTensor& y) { + auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + AddKernel(dev_ctx, x, y, &dense_out); return dense_out; } template DenseTensor Subtract(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& y, - int axis) { - auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - SubtractKernel(dev_ctx, x, y, axis, &dense_out); + const DenseTensor& y) { + auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + SubtractKernel(dev_ctx, x, y, &dense_out); return dense_out; } template DenseTensor Divide(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& y, - int axis) { - auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - DivideKernel(dev_ctx, x, y, axis, &dense_out); + const DenseTensor& y) { + auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + DivideKernel(dev_ctx, x, y, &dense_out); return dense_out; } template DenseTensor Multiply(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& y, - int axis) { - auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - MultiplyKernel(dev_ctx, x, y, axis, &dense_out); + const DenseTensor& y) { + auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1); + auto dense_out = 
pten::Empty(dev_ctx, std::move(out_meta)); + MultiplyKernel(dev_ctx, x, y, &dense_out); return dense_out; } @@ -130,8 +153,7 @@ DenseTensor Mean(const Context& dev_ctx, bool keep_dim) { auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim); auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); - bool reduce_all = false; - MeanKernel(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out); + MeanKernel(dev_ctx, x, axis, keep_dim, &dense_out); return dense_out; } @@ -144,12 +166,7 @@ DenseTensor Sum(const Context& dev_ctx, auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype); auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); - // The real value of reduce_all will be get in kernel - // so use default value(false) is OK. - bool reduce_all = false; - - SumKernel( - dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out); + SumKernel(dev_ctx, x, axis, keep_dim, dtype, &dense_out); return dense_out; } diff --git a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc index 0bc16371c0..e5d9b05eec 100644 --- a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc +++ b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc @@ -54,11 +54,10 @@ TEST(DEV_API, add) { for (size_t i = 0; i < 10; ++i) { dense_y_data[i] = i * 2.0; } - int axis = 1; // 2. test API pten::CPUContext dev_ctx; - auto dense_out = pten::Add(dev_ctx, dense_x, dense_y, axis); + auto dense_out = pten::Add(dev_ctx, dense_x, dense_y); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -101,11 +100,10 @@ TEST(DEV_API, subtract) { for (size_t i = 0; i < 10; ++i) { dense_y_data[i] = i * 2.0; } - int axis = 1; // 2. test API pten::CPUContext dev_ctx; - auto dense_out = pten::Subtract(dev_ctx, dense_x, dense_y, axis); + auto dense_out = pten::Subtract(dev_ctx, dense_x, dense_y); // 3. 
check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -148,11 +146,10 @@ TEST(DEV_API, divide) { for (size_t i = 0; i < 10; ++i) { dense_y_data[i] = i * 2.0 + 1; } - int axis = 1; // 2. test API pten::CPUContext dev_ctx; - auto dense_out = pten::Divide(dev_ctx, dense_x, dense_y, axis); + auto dense_out = pten::Divide(dev_ctx, dense_x, dense_y); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -195,11 +192,10 @@ TEST(DEV_API, multiply) { for (size_t i = 0; i < 10; ++i) { dense_y_data[i] = i * 2.0; } - int axis = 1; // 2. test API pten::CPUContext dev_ctx; - auto dense_out = pten::Multiply(dev_ctx, dense_x, dense_y, axis); + auto dense_out = pten::Multiply(dev_ctx, dense_x, dense_y); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 1bf5344e83..a0d7ce84f7 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -6,7 +6,6 @@ param : [x, y, -1] kernel : func : add - param : [x, y, -1] - api : cast args : (const Tensor& x, DataType out_dtype) @@ -44,7 +43,6 @@ param : [x, y, -1] kernel : func : divide - param : [x, y, -1] - api : dot args : (const Tensor& x, const Tensor& y) @@ -130,7 +128,6 @@ param: [x, axis, keep_dim] kernel : func : mean - param : [x, axis, keep_dim, false] - api : multiply args : (const Tensor& x, const Tensor& y) @@ -140,7 +137,6 @@ param : [x, y, -1] kernel : func : multiply - param : [x, y, -1] - api : ones_like args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED) @@ -172,7 +168,6 @@ param : [x, y, -1] kernel : func : subtract - param : [x, y, -1] - api : sum args : (const Tensor& x, const std::vector& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false) @@ -182,7 +177,7 @@ param: [x, axis, keep_dim, dtype] kernel : func : sum - param : [x, axis, keep_dim, false, DataType::UNDEFINED] + param : 
[x, axis, keep_dim, dtype] data_type : x - api : zeros_like -- GitLab