diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc index 636a594a657cb0744aac161d928ff9078b1f92bc..d02e472cb916f05b3ab1db00c987156112cba7e5 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc @@ -23,7 +23,7 @@ #include "paddle/phi/core/kernel_registry.h" USE_OP_ITSELF(scale); -USE_OP(elementwise_mul); +USE_OP_ITSELF(elementwise_mul); USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add_grad); diff --git a/paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc index d14317712b579b8f04889c3a18e4231d96513225..9c6ea51fe5a356d0dfd1d551819114a6c4549c3c 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc @@ -104,4 +104,4 @@ TEST(elementwise_op, plugin) { } // namespace paddle USE_OP_ITSELF(elementwise_add); -USE_OP(elementwise_mul); +USE_OP_ITSELF(elementwise_mul); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 830e09eeae4811eb44bd4e21e17fe83ee44c592d..45b6f7cb391949043ff4e6725f7e3f0c18eef278 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -20,35 +20,6 @@ limitations under the License. */ namespace paddle { namespace operators { - -template -struct SameDimsElemwiseMul< - platform::CPUDeviceContext, T, - typename std::enable_if::value>::type> { - void operator()(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - framework::Tensor *z) { - auto blas = phi::funcs::GetBlas(ctx); - blas.VMUL(x->numel(), x->data(), y->data(), z->data()); - } -}; - -template -struct SameDimsElemwiseMul< - platform::CPUDeviceContext, T, - typename std::enable_if::value>::type> { - void operator()(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - framework::Tensor *z) { - auto eigen_x = framework::EigenVector::Flatten(*x); - auto eigen_y = framework::EigenVector::Flatten(*y); - auto eigen_z = framework::EigenVector::Flatten(*z); - auto &place = *ctx.template device_context() - .eigen_device(); - eigen_z.device(place) = eigen_x * eigen_y; - } -}; - class ElementwiseMulOpMaker : public ElementwiseOpMaker { protected: std::string GetName() const override { return "Mul"; } @@ -160,20 +131,6 @@ REGISTER_OPERATOR( REGISTER_OPERATOR(elementwise_mul_triple_grad, ops::ElementwiseOpTripleGrad); -REGISTER_OP_CPU_KERNEL( - elementwise_mul, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel>, - ops::ElementwiseMulKernel>); - REGISTER_OP_VERSION(elementwise_mul) .AddCheckpoint( R"ROC(Register elementwise_mul for adding the attribute of Scale_y)ROC", diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu deleted file mode 100644 index f7b9fd1e265f5d3f107e734f9ffdcc90e7f6cc77..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" -#include "paddle/phi/backends/gpu/gpu_context.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -namespace paddle { -namespace operators { - -template -class ElementwiseMulKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto x_var = ctx.InputVar("X"); - PADDLE_ENFORCE_EQ(x_var != nullptr, true, - platform::errors::InvalidArgument( - "Cannot get input Variable X, Variable name = %s.", - ctx.InputName("X"))); - const auto& cuda_ctx = - ctx.template device_context(); - if (x_var->IsType()) { - framework::Tensor x_for_selectedrows; - std::vector ins; - std::vector outs; - int axis = - PackTensorsIntoVector(ctx, &ins, &outs, &x_for_selectedrows); - paddle::operators::LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, MulFunctor()); - } else if (x_var->IsType()) { - auto* x_lod = ctx.Input("X"); - auto* y_lod = ctx.Input("Y"); - auto* z_lod = ctx.Output("Out"); - z_lod->mutable_data(ctx.GetPlace()); - - int axis = ctx.Attr("axis"); - auto pt_x = paddle::experimental::MakePhiDenseTensor(*x_lod); - auto pt_y = paddle::experimental::MakePhiDenseTensor(*y_lod); - auto pt_z = paddle::experimental::MakePhiDenseTensor(*z_lod); - phi::MultiplyRawKernel(static_cast(cuda_ctx), - *pt_x.get(), *pt_y.get(), axis, pt_z.get()); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "X's type[%s] is not supported by elementwise_op. X's type should be " - "LoDTensor or SelectedRows.", - framework::ToTypeName(x_var->Type()))); - } - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL( - elementwise_mul, ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel>, - ops::ElementwiseMulKernel>); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 6f4aba93d56e2a8227a8578067ac934d41243fb6..e2dd0e36d400afe0d91bdcad74b5f9de2a4c8854 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -58,85 +58,5 @@ class ElementwiseMulOp : public ElementwiseOp { } }; -template -void default_elementwise_mul(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, framework::Tensor* z) { - int axis = ctx.Attr("axis"); - auto x_dims = x->dims(); - auto y_dims = y->dims(); - if (x_dims.size() >= y_dims.size()) { - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - MulFunctor(), z); - } else { - ElementwiseComputeEx, DeviceContext, T>( - ctx, x, y, axis, InverseMulFunctor(), z); - } -} - -template -struct SameDimsElemwiseMul { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z); -}; - -template -class ElementwiseMulKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto x_var = ctx.InputVar("X"); - PADDLE_ENFORCE_EQ(x_var != nullptr, true, - platform::errors::InvalidArgument( - "Cannot get input Variable X, Variable name = %s.", - ctx.InputName("X"))); - auto* y = ctx.Input("Y"); - - framework::Tensor x, *z; - if (x_var->IsType()) { - PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true, - platform::errors::InvalidArgument( - "For elementwise_op, if X is Sparse, Y must be " - "scalar. But reveived the size of Y = %s.", - y->dims().size())); - auto& x_sele = x_var->Get(); - auto out_sele = ctx.Output("Out"); - x = x_sele.value(); - out_sele->set_rows(x_sele.rows()); - out_sele->set_height(x_sele.height()); - out_sele->mutable_value()->Resize(x_sele.value().dims()); - out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type()); - z = ctx.Output("Out")->mutable_value(); - z->mutable_data(ctx.GetPlace()); - auto dims_equal = x.dims() == y->dims(); - if (dims_equal) { - SameDimsElemwiseMul same_dims_mul; - same_dims_mul(ctx, &x, y, z); - } else { - default_elementwise_mul(ctx, &x, y, z); - } - } else if (x_var->IsType()) { - auto* x_lod = ctx.Input("X"); - auto* z_lod = ctx.Output("Out"); - z_lod->mutable_data(ctx.GetPlace()); - - auto& dev_ctx = ctx.device_context(); - int axis = ctx.Attr("axis"); - auto pt_x = paddle::experimental::MakePhiDenseTensor(*x_lod); - auto pt_y = paddle::experimental::MakePhiDenseTensor(*y); - auto pt_z = paddle::experimental::MakePhiDenseTensor(*z_lod); - phi::MultiplyRawKernel( - static_cast::TYPE&>(dev_ctx), - *pt_x.get(), *pt_y.get(), axis, pt_z.get()); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "X's type[%s] is not supported by elementwise_op. X's type should be " - "LoDTensor or SelectedRows.", - framework::ToTypeName(x_var->Type()))); - } - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index 23428dd403e9b1ef62007c7b9193ed3b8482cab3..b5fb0c54c7812f6a022d3c61f6e225a576765f00 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -27,7 +27,7 @@ USE_OP_ITSELF(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); -USE_OP(elementwise_mul); +USE_OP_ITSELF(elementwise_mul); USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN); USE_OP_ITSELF(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc index 019d4fed5b28eaed72a370bfb51c1a75807964fd..84c379ab280912200329e2a66e1d56c9250d4cde 100644 --- a/paddle/phi/kernels/elementwise_kernel.cc +++ b/paddle/phi/kernels/elementwise_kernel.cc @@ -202,6 +202,7 @@ PD_REGISTER_KERNEL(multiply, int64_t, bool, phi::dtype::float16, + phi::dtype::bfloat16, complex64, complex128) {} PD_REGISTER_KERNEL(maximum, diff --git a/paddle/phi/kernels/selected_rows/elementwise_kernel.cc b/paddle/phi/kernels/selected_rows/elementwise_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..7fba3244a60eedc032290aa82f99e9eb01e0ff6b --- /dev/null +++ b/paddle/phi/kernels/selected_rows/elementwise_kernel.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/selected_rows/elementwise_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/elementwise_kernel.h" + +namespace phi { +namespace sr { + +template +void MultiplyRawKernel(const Context& dev_ctx, + const SelectedRows& x, + const DenseTensor& y, + int axis, + SelectedRows* out) { + PADDLE_ENFORCE_EQ(y.dims().size() == 1 && y.dims()[0] == 1, + true, + phi::errors::InvalidArgument( + "For MultiplyKernel, if X is Sparse, Y must be " + "scalar. But reveived the size of Y = %s.", + y.dims().size())); + out->set_rows(x.rows()); + out->set_height(x.height()); + auto z = out->mutable_value(); + z->Resize(x.value().dims()); + dev_ctx.Alloc(z, x.value().dtype()); + MultiplyRawKernel(dev_ctx, x.value(), y, axis, z); +} + +template +void MultiplyKernel(const Context& dev_ctx, + const SelectedRows& x, + const DenseTensor& y, + SelectedRows* out) { + int axis = -1; + MultiplyRawKernel(dev_ctx, x, y, axis, out); +} + +} // namespace sr +} // namespace phi + +using complex64 = ::phi::dtype::complex; +using complex128 = ::phi::dtype::complex; + +PD_REGISTER_KERNEL(multiply_raw_sr, + CPU, + ALL_LAYOUT, + phi::sr::MultiplyRawKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + complex64, + complex128) {} +PD_REGISTER_KERNEL(multiply_sr, + CPU, + ALL_LAYOUT, + phi::sr::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + complex64, + complex128) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(multiply_raw_sr, + GPU, + ALL_LAYOUT, + phi::sr::MultiplyRawKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::float16, + complex64, + complex128) {} +PD_REGISTER_KERNEL(multiply_sr, + GPU, + ALL_LAYOUT, + phi::sr::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::float16, + complex64, + complex128) {} +#endif diff --git a/paddle/phi/kernels/selected_rows/elementwise_kernel.h b/paddle/phi/kernels/selected_rows/elementwise_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f90fd47882b89f7c040b7b7c79c3fae7dec6a7e6 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/elementwise_kernel.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { +namespace sr { + +template +void MultiplyRawKernel(const Context& dev_ctx, + const SelectedRows& x, + const DenseTensor& y, + int axis, + SelectedRows* out); + +template +void MultiplyKernel(const Context& dev_ctx, + const SelectedRows& x, + const DenseTensor& y, + SelectedRows* out); + +} // namespace sr +} // namespace phi diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index 7f00af6f9af86f35709c2d120a7a47917d8d8431..a94e46452205268391adcad8e7bc26da422e6526 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -42,8 +42,12 @@ KernelSignature ElementwiseMulOpArgumentMapping( return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"}); } return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"}); + } else { + if (axis == -1) { + return KernelSignature("multiply_sr", {"X", "Y"}, {}, {"Out"}); + } + return KernelSignature("multiply_raw_sr", {"X", "Y"}, {"axis"}, {"Out"}); } - return KernelSignature("unregistered", {}, {}, {}); } KernelSignature ElementwiseDivOpArgumentMapping(