From 32d9beef869311dd1b25b1fa06dd5527f9365b62 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 23 Nov 2021 11:23:13 +0800 Subject: [PATCH] [PTen]Elementwise_div Kernel Refactor (#37418) * elementwise_div refactor * fix compile bugs in windows ci --- .../elementwise/elementwise_div_op.cc | 25 -------- .../elementwise/elementwise_div_op.cu | 16 ----- .../elementwise/elementwise_div_op.h | 27 ++++----- .../operators/elementwise/elementwise_op.h | 6 ++ paddle/pten/api/include/math.h | 2 + paddle/pten/api/lib/math.cc | 35 +++++++++++ paddle/pten/include/math.h | 21 +++++-- paddle/pten/kernels/cpu/math.cc | 34 +++++++++++ paddle/pten/kernels/cpu/math.h | 6 ++ paddle/pten/kernels/cuda/math.cu | 28 +++++++++ paddle/pten/kernels/cuda/math.h | 7 +++ .../pten/kernels/functions/blas/elementwise.h | 9 +++ .../functions/general/elementwise_functor.h | 60 +++++++++++++++++++ paddle/pten/tests/api/test_elementwise_api.cc | 54 +++++++++++++++++ .../tests/kernels/test_elementwise_dev_api.cc | 57 +++++++++++++++++- 15 files changed, 326 insertions(+), 61 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 9a899ec11b4..38cd232e4d1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -22,31 +22,6 @@ limitations under the License. */ namespace paddle { namespace operators { -template -struct SameDimsElemwiseDiv< - platform::CPUDeviceContext, T, - typename std::enable_if::value>::type> { - void operator()(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - framework::Tensor *z) { - auto blas = math::GetBlas(ctx); - blas.VDIV(x->numel(), x->data(), y->data(), z->data()); - } -}; - -// use default div function for int32/int64 type because of divison zero -// checking. -template -struct SameDimsElemwiseDiv< - platform::CPUDeviceContext, T, - typename std::enable_if::value>::type> { - void operator()(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - framework::Tensor *z) { - default_elementwise_div(ctx, x, y, z); - } -}; - class ElementwiseDivOpMaker : public ElementwiseOpMaker { protected: std::string GetName() const override { return "Div"; } diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index ce487f284d9..80089243f25 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -23,22 +23,6 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -template -class ElementwiseDivKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - std::vector ins; - std::vector outs; - const auto& cuda_ctx = - ctx.template device_context(); - - int axis = PackTensorsIntoVector(ctx, &ins, &outs); - LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, DivFunctor()); - } -}; - template static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y, const T* out, diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 0ec42e54e14..374dda9e83d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -23,6 +23,12 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" +#include "paddle/fluid/framework/pten_utils.h" + +// only can include the headers in paddle/pten/include dirs +#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/include/core.h" +#include "paddle/pten/include/math.h" namespace paddle { namespace operators { @@ -42,13 +48,6 @@ void default_elementwise_div(const framework::ExecutionContext& ctx, } } -template -struct SameDimsElemwiseDiv { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z); -}; - template class ElementwiseDivKernel : public framework::OpKernel { public: @@ -58,13 +57,13 @@ class ElementwiseDivKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); - auto dims_equal = x->dims() == y->dims(); - if (dims_equal) { - SameDimsElemwiseDiv same_dims_div; - same_dims_div(ctx, x, y, z); - } else { - default_elementwise_div(ctx, x, y, z); - } + auto& dev_ctx = ctx.device_context(); + int axis = ctx.Attr("axis"); + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); + pten::ElementwiseDiv(dev_ctx, *pt_x.get(), *pt_y.get(), axis, + pt_z.get()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 0b4865b4e87..be4c25ef4c5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -154,6 +154,12 @@ class ElementwiseOp : public framework::OperatorWithKernel { {"axis"}, {"Out"}); } } + if (Type() == "elementwise_div") { + if (ctx.InputVar("X")->IsType()) { + return framework::KernelSignature("elementwise_div", {"X", "Y"}, + {"axis"}, {"Out"}); + } + } return framework::KernelSignature("None", {"X"}, {}, {"Out"}); } }; diff --git a/paddle/pten/api/include/math.h b/paddle/pten/api/include/math.h index a49d6c116ab..cdc9db55d95 100644 --- a/paddle/pten/api/include/math.h +++ b/paddle/pten/api/include/math.h @@ -26,5 +26,7 @@ PD_DLL_DECL Tensor mean(const Tensor& x); PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y); PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y); + +PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y); } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/math.cc b/paddle/pten/api/lib/math.cc index b7fcdd027cf..d85d5e66d03 100644 --- a/paddle/pten/api/lib/math.cc +++ b/paddle/pten/api/lib/math.cc @@ -137,6 +137,41 @@ PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) { return out; } + +PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) { + // 1. Get kernel signature and kernel + auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( + "elementwise_div", kernel_key); + + // 2. Get Device Context + auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto kernel_context = pten::KernelContext(dev_ctx); + + // 3. Auto data transform + auto dense_x = std::dynamic_pointer_cast(x.impl()); + kernel_context.EmplaceBackInput(dense_x); + auto dense_y = std::dynamic_pointer_cast(y.impl()); + kernel_context.EmplaceBackInput(dense_y); + kernel_context.EmplaceBackAttr(-1); + + // 4. InferShape + auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1); + + // 5. Prepare outputs + Tensor out; + const auto allocator = std::make_shared( + pten::TransToFluidPlace(kernel_key.backend())); + auto dense_out = std::make_shared(allocator, out_meta); + kernel_context.EmplaceBackOutput(dense_out); + out.set_impl(dense_out); + + // 6. Call kernel + kernel(&kernel_context); + + return out; +} } // namespace experimental } // namespace paddle diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h index 72894dc74ba..ec0bde16129 100644 --- a/paddle/pten/include/math.h +++ b/paddle/pten/include/math.h @@ -75,10 +75,10 @@ DenseTensor Scale(const ContextT& dev_ctx, } template -DenseTensor ElementwiseAdd(const ContextT& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int axis) { +DenseTensor Add(const ContextT& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis) { auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis); const auto allocator = std::make_shared( @@ -102,4 +102,17 @@ DenseTensor Subtract(const ContextT& dev_ctx, return dense_out; } +template +DenseTensor Divide(const ContextT& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis) { + auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis); + const auto allocator = + std::make_shared( + dev_ctx.GetPlace()); + pten::DenseTensor dense_out(allocator, out_meta); + ElementwiseDiv(dev_ctx, x, y, axis, &dense_out); + return dense_out; +} } // namespace pten diff --git a/paddle/pten/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc index 9b91aa347a4..68378170c45 100644 --- a/paddle/pten/kernels/cpu/math.cc +++ b/paddle/pten/kernels/cpu/math.cc @@ -114,6 +114,30 @@ void ElementwiseSub(const CPUContext& dev_ctx, } } +template +void ElementwiseDiv(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + // allocate memory for out + out->mutable_data(); + if (x.dims() == y.dims() && std::is_floating_point::value) { + SameDimsElementwiseCompute>()( + dev_ctx, x, y, out); + } else { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + if (x_dims.size() >= y_dims.size()) { + ElementwiseCompute, T>( + dev_ctx, x, y, axis, general::DivFunctor(), out); + } else { + ElementwiseCompute, T>( + dev_ctx, x, y, axis, general::InverseDivFunctor(), out); + } + } +} + } // namespace pten // TODO(chenweihang): replace by better impl @@ -174,3 +198,13 @@ PT_REGISTER_KERNEL("elementwise_sub", int64_t, complex64, complex128) {} +PT_REGISTER_KERNEL("elementwise_div", + CPU, + ANY, + pten::ElementwiseDiv, + float, + double, + int, + int64_t, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/cpu/math.h b/paddle/pten/kernels/cpu/math.h index 2cbf14c5f87..7495b838ff4 100644 --- a/paddle/pten/kernels/cpu/math.h +++ b/paddle/pten/kernels/cpu/math.h @@ -60,4 +60,10 @@ void ElementwiseSub(const CPUContext& dev_ctx, int axis, DenseTensor* out); +template +void ElementwiseDiv(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); } // namespace pten diff --git a/paddle/pten/kernels/cuda/math.cu b/paddle/pten/kernels/cuda/math.cu index 92a1eeef923..ca84e92c4c7 100644 --- a/paddle/pten/kernels/cuda/math.cu +++ b/paddle/pten/kernels/cuda/math.cu @@ -158,6 +158,23 @@ void ElementwiseSub(const CUDAContext& dev_ctx, dev_ctx, inputs, &outputs, axis, general::SubFunctor()); } +template +void ElementwiseDiv(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + std::vector inputs; + std::vector outputs; + inputs.emplace_back(&x); + inputs.emplace_back(&y); + // allocate memory for out + out->mutable_data(); + outputs.emplace_back(out); + LaunchElementwiseCudaKernel( + dev_ctx, inputs, &outputs, axis, general::DivFunctor()); +} + } // namespace pten // TODO(chenweihang): replace by better impl @@ -217,3 +234,14 @@ PT_REGISTER_KERNEL("elementwise_sub", float16, complex64, complex128) {} +PT_REGISTER_KERNEL("elementwise_div", + CUDA, + ANY, + pten::ElementwiseDiv, + float, + double, + int, + int64_t, + float16, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/cuda/math.h b/paddle/pten/kernels/cuda/math.h index 3d66991d6fd..9f70edac968 100644 --- a/paddle/pten/kernels/cuda/math.h +++ b/paddle/pten/kernels/cuda/math.h @@ -62,6 +62,13 @@ void ElementwiseSub(const CUDAContext& dev_ctx, int axis, DenseTensor* out); +template +void ElementwiseDiv(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + } // namespace pten #endif diff --git a/paddle/pten/kernels/functions/blas/elementwise.h b/paddle/pten/kernels/functions/blas/elementwise.h index 7c137e466d3..34946dcbf8e 100644 --- a/paddle/pten/kernels/functions/blas/elementwise.h +++ b/paddle/pten/kernels/functions/blas/elementwise.h @@ -38,5 +38,14 @@ void ElementwiseSub(const DevCtx& dev_ctx, blas.VSUB(x.numel(), x.data(), y.data(), out->mutable_data()); } +template +void ElementwiseDiv(const DevCtx& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + auto blas = paddle::operators::math::GetBlas(dev_ctx); + blas.VDIV(x.numel(), x.data(), y.data(), out->mutable_data()); +} + } // namespace blas } // namespace pten diff --git a/paddle/pten/kernels/functions/general/elementwise_functor.h b/paddle/pten/kernels/functions/general/elementwise_functor.h index 2342b68f188..f0d4305ea6c 100644 --- a/paddle/pten/kernels/functions/general/elementwise_functor.h +++ b/paddle/pten/kernels/functions/general/elementwise_functor.h @@ -114,5 +114,65 @@ struct InverseSubFunctor { inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b - a; } }; +// Divide +template +struct SameDimsDivFunctor { + void operator()(const DevCtx& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* z); +}; + +template +struct SameDimsDivFunctor< + DevCtx, + T, + typename std::enable_if::value>::type> { + void operator()(const DevCtx& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* z) { + paddle::platform::errors::InvalidArgument( + "If use SameDimsDivFunctor, template args(T) must be floating point. "); + } +}; + +template +struct SameDimsDivFunctor< + DevCtx, + T, + typename std::enable_if::value>::type> { + void operator()(const DevCtx& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* z) { + blas::ElementwiseDiv(dev_ctx, x, y, z); + } +}; + +#define DIV_ERROR_INFO \ + "InvalidArgumentError: Integer division by zero encountered in " \ + "(floor) divide. Please check the input value." + +template +struct DivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; } +}; + +template +struct DivFunctor::value>::type> { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + // For int32/int64, need to check whether the divison is zero. + PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); + return a / b; + } +}; + +template +struct InverseDivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b / a; } +}; + } // namespace general } // namespace pten diff --git a/paddle/pten/tests/api/test_elementwise_api.cc b/paddle/pten/tests/api/test_elementwise_api.cc index e6ac253e1ad..be4e45370c3 100644 --- a/paddle/pten/tests/api/test_elementwise_api.cc +++ b/paddle/pten/tests/api/test_elementwise_api.cc @@ -131,3 +131,57 @@ TEST(API, subtract) { ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f); ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f); } + +// TODO(chenweihang): Remove this test after the API is used in the dygraph +TEST(API, divide) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 10}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x->mutable_data(); + + auto dense_y = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({10}), + pten::DataLayout::NCHW)); + auto* dense_y_data = dense_y->mutable_data(); + + float div[3][10] = {0.0}; + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0; + div[i][j] = (i * 10 + j) * 1.0 / (j * 2.0 + 1); + } + } + for (size_t i = 0; i < 10; ++i) { + dense_y_data[i] = i * 2.0 + 1; + } + paddle::experimental::Tensor x(dense_x); + paddle::experimental::Tensor y(dense_y); + + // 2. test API + auto out = paddle::experimental::divide(x, y); + + // 3. check result + ASSERT_EQ(out.shape().size(), 2UL); + ASSERT_EQ(out.shape()[0], 3); + ASSERT_EQ(out.numel(), 30); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + + auto expect_result = div; + auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto actual_result0 = dense_out->data()[0]; + auto actual_result1 = dense_out->data()[1]; + auto actual_result2 = dense_out->data()[10]; + ASSERT_NEAR(expect_result[0][0], actual_result0, 1e-6f); + ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f); + ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f); +} diff --git a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc index 8dafce1fba7..6e9c25e23f0 100644 --- a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc +++ b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc @@ -24,7 +24,7 @@ limitations under the License. */ namespace framework = paddle::framework; using DDim = paddle::framework::DDim; -TEST(DEV_API, elementwise_add) { +TEST(DEV_API, add) { // 1. create tensor const auto alloc = std::make_shared( paddle::platform::CPUPlace()); @@ -56,7 +56,7 @@ TEST(DEV_API, elementwise_add) { auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); // 2. test API - auto dense_out = pten::ElementwiseAdd( + auto dense_out = pten::Add( *(static_cast(dev_ctx)), dense_x, dense_y, @@ -129,3 +129,56 @@ TEST(DEV_API, subtract) { ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f); ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f); } + +TEST(DEV_API, divide) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + pten::DenseTensor dense_x(alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 10}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(); + + pten::DenseTensor dense_y(alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({10}), + pten::DataLayout::NCHW)); + auto* dense_y_data = dense_y.mutable_data(); + + float div[3][10] = {0.0}; + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0; + div[i][j] = (i * 10 + j) * 1.0 / (j * 2.0 + 1); + } + } + for (size_t i = 0; i < 10; ++i) { + dense_y_data[i] = i * 2.0 + 1; + } + int axis = 1; + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + + // 2. test API + auto dense_out = pten::Divide( + *(static_cast(dev_ctx)), + dense_x, + dense_y, + axis); + + // 3. check result + ASSERT_EQ(dense_out.dims().size(), 2); + ASSERT_EQ(dense_out.dims()[0], 3); + ASSERT_EQ(dense_out.meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(dense_out.meta().layout, pten::DataLayout::NCHW); + + auto expect_result = div; + auto actual_result0 = dense_out.data()[0]; + auto actual_result1 = dense_out.data()[1]; + auto actual_result2 = dense_out.data()[10]; + ASSERT_NEAR(expect_result[0][0], actual_result0, 1e-6f); + ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f); + ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f); +} -- GitLab