From 1b491818ab833e407a749ba640f29d964ebba80e Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 24 Mar 2022 10:34:14 +0800 Subject: [PATCH] [Phi] Move mul op kernel into phi (#40833) * add mul phi kernel * remove mul op kernel * remove original mul grad op * fix cinn test * fix dygraph test failed --- .../share_varinfo_into_cinn_pass_test.cc | 2 +- .../paddle2cinn/build_cinn_pass_test.cc | 2 +- .../paddle2cinn/cinn_compiler_test.cc | 2 +- paddle/fluid/imperative/tests/test_eager.cc | 2 +- paddle/fluid/imperative/tests/test_hooks.cc | 6 +- paddle/fluid/imperative/tests/test_layer.cc | 2 +- paddle/fluid/imperative/tests/test_tracer.cc | 8 +- .../inference/tensorrt/convert/test_fc_op.cc | 2 +- .../inference/tensorrt/convert/test_mul_op.cc | 2 +- .../fluid/operators/mkldnn/mul_mkldnn_op.cc | 5 +- paddle/fluid/operators/mul_op.cc | 18 +- paddle/fluid/operators/mul_op.cu.cc | 30 --- paddle/fluid/operators/mul_op.h | 207 ------------------ paddle/fluid/operators/mul_op_npu.cc | 2 +- paddle/fluid/operators/mul_op_xpu.cc | 2 +- paddle/phi/kernels/cpu/matmul_grad_kernel.cc | 14 ++ paddle/phi/kernels/cpu/matmul_kernel.cc | 7 + paddle/phi/kernels/gpu/matmul_grad_kernel.cu | 16 ++ paddle/phi/kernels/gpu/matmul_kernel.cu | 8 + .../kernels/impl/matmul_grad_kernel_impl.h | 159 ++++++++++++++ paddle/phi/kernels/impl/matmul_kernel_impl.h | 30 +++ paddle/phi/kernels/matmul_grad_kernel.h | 24 ++ paddle/phi/kernels/matmul_kernel.h | 10 + paddle/phi/ops/compat/mul_sig.cc | 41 ++++ 24 files changed, 336 insertions(+), 265 deletions(-) delete mode 100644 paddle/fluid/operators/mul_op.cu.cc delete mode 100644 paddle/fluid/operators/mul_op.h create mode 100644 paddle/phi/ops/compat/mul_sig.cc diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc index ed9f623072..60f4e4b309 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc @@ -24,7 +24,7 @@ #include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/program_desc.h" -USE_OP(mul); +USE_OP_ITSELF(mul); USE_OP(cinn_launch); USE_OP_ITSELF(elementwise_add); namespace paddle::framework { diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index 47dffd47b7..c11c7124b6 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -674,7 +674,7 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) { } // namespace paddle USE_PASS(build_cinn_pass); -USE_OP(mul); +USE_OP_ITSELF(mul); USE_OP_ITSELF(relu); USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(relu_grad); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc index cdccc4c554..44f4424d70 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc @@ -300,6 +300,6 @@ TEST(CinnCompilerTest, Compile) { USE_PASS(build_cinn_pass); USE_PASS(graph_viz_pass); -USE_OP(mul); +USE_OP_ITSELF(mul); USE_OP_ITSELF(relu); USE_OP_ITSELF(elementwise_add); diff --git a/paddle/fluid/imperative/tests/test_eager.cc b/paddle/fluid/imperative/tests/test_eager.cc index 7ec21385bb..4a0b99518a 100644 --- a/paddle/fluid/imperative/tests/test_eager.cc +++ b/paddle/fluid/imperative/tests/test_eager.cc @@ -98,4 +98,4 @@ TEST(test_var_helper, eager_var_helper) { } // namespace imperative } // namespace paddle -USE_OP(mul); +USE_OP_ITSELF(mul); diff --git a/paddle/fluid/imperative/tests/test_hooks.cc b/paddle/fluid/imperative/tests/test_hooks.cc index 02a1689c23..eb7e327662 100644 --- a/paddle/fluid/imperative/tests/test_hooks.cc +++ b/paddle/fluid/imperative/tests/test_hooks.cc @@ -28,6 +28,8 @@ PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(matmul_with_flatten, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(matmul_with_flatten_grad, CPU, ALL_LAYOUT); namespace platform = paddle::platform; namespace framework = paddle::framework; @@ -267,7 +269,7 @@ TEST(TestHooks, TestGradVarLeafBackwardHookWithSortedGradAccmulated) { } // namespace imperative } // namespace paddle -USE_OP(mul); -USE_OP(mul_grad); +USE_OP_ITSELF(mul); +USE_OP_ITSELF(mul_grad); USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add_grad); diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index 3fa87d415d..3e5ab9ab96 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -416,4 +416,4 @@ TEST(test_layer, test_eager) { } // namespace imperative } // namespace paddle -USE_OP(mul); +USE_OP_ITSELF(mul); diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc index 75876e07fb..1c3a04b51a 100644 --- a/paddle/fluid/imperative/tests/test_tracer.cc +++ b/paddle/fluid/imperative/tests/test_tracer.cc @@ -34,9 +34,13 @@ PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sum, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sum_grad, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(matmul_with_flatten, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(matmul_with_flatten_grad, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_DECLARE_KERNEL(add_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sum_grad, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(matmul_with_flatten, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(matmul_with_flatten_grad, GPU, ALL_LAYOUT); #endif namespace imperative = paddle::imperative; @@ -598,8 +602,8 @@ TEST(test_tracer, eager_tracer) { } // namespace imperative } // namespace paddle -USE_OP(mul); -USE_OP(mul_grad); +USE_OP_ITSELF(mul); +USE_OP_ITSELF(mul_grad); USE_OP_ITSELF(reduce_sum); USE_OP_ITSELF(reduce_sum_grad); USE_OP_ITSELF(elementwise_add); diff --git a/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc index 1ae2668e73..8134d38946 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc @@ -43,4 +43,4 @@ TEST(fc_op, test) { } // namespace tensorrt } // namespace inference } // namespace paddle -USE_OP(mul); +USE_OP_ITSELF(mul); diff --git a/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc b/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc index 282f53559a..86cb7543d4 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc @@ -46,4 +46,4 @@ TEST(MulOpConverter, main) { } // namespace inference } // namespace paddle -USE_OP(mul); +USE_OP_ITSELF(mul); diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index fe9faab7d6..0f70b67bbb 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include -#include "paddle/fluid/operators/mul_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/mkldnn_reuse.h" namespace phi { @@ -46,6 +46,9 @@ using dnnl::memory; using dnnl::prop_kind; using dnnl::stream; +constexpr int kMULMKLDNNINT8 = 1; +constexpr int kMULMKLDNNFP32 = 2; + template class MulPrimitiveFactory { public: diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index bc57b42912..6738f15ef7 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/mul_op.h" #include #include #include #include +#include "paddle/fluid/framework/op_registry.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -27,6 +27,9 @@ namespace operators { using framework::OpKernelType; using framework::Tensor; +constexpr int kMULMKLDNNINT8 = 1; +constexpr int kMULMKLDNNFP32 = 2; + class MulOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -354,16 +357,3 @@ REGISTER_OPERATOR(mul_grad, ops::MulGradOp, ops::MulDoubleGradMaker); REGISTER_OPERATOR(mul_grad_grad, ops::MulDoubleGradOp); - -REGISTER_OP_CPU_KERNEL( - mul, ops::MulKernel, - ops::MulKernel); - -REGISTER_OP_CPU_KERNEL( - mul_grad, ops::MulGradKernel, - ops::MulGradKernel); - -REGISTER_OP_CPU_KERNEL( - mul_grad_grad, - ops::MulDoubleGradKernel, - ops::MulDoubleGradKernel); diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc deleted file mode 100644 index 6e841712b9..0000000000 --- a/paddle/fluid/operators/mul_op.cu.cc +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/mul_op.h" -#include "paddle/fluid/platform/float16.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, - ops::MulKernel, - ops::MulKernel); -REGISTER_OP_CUDA_KERNEL( - mul_grad, ops::MulGradKernel, - ops::MulGradKernel, - ops::MulGradKernel); -REGISTER_OP_CUDA_KERNEL( - mul_grad_grad, - ops::MulDoubleGradKernel, - ops::MulDoubleGradKernel); diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h deleted file mode 100644 index ce91c6dd0e..0000000000 --- a/paddle/fluid/operators/mul_op.h +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -constexpr int kMULMKLDNNINT8 = 1; -constexpr int kMULMKLDNNFP32 = 2; - -template -class MulKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - const Tensor* y = context.Input("Y"); - Tensor* z = context.Output("Out"); - const Tensor x_matrix = - x->dims().size() > 2 - ? framework::ReshapeToMatrix( - *x, context.template Attr("x_num_col_dims")) - : *x; - const Tensor y_matrix = - y->dims().size() > 2 - ? framework::ReshapeToMatrix( - *y, context.template Attr("y_num_col_dims")) - : *y; - - z->mutable_data(context.GetPlace()); - auto z_dim = z->dims(); - if (z_dim.size() != 2) { - z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - - auto blas = phi::funcs::GetBlas(context); - - blas.MatMul(x_matrix, y_matrix, z); - if (z_dim.size() != 2) { - z->Resize(z_dim); - } - } -}; - -template -class MulGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - int x_num_col_dims = ctx.template Attr("x_num_col_dims"); - int y_num_col_dims = ctx.template Attr("y_num_col_dims"); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto x_matrix = x->dims().size() > 2 - ? framework::ReshapeToMatrix(*x, x_num_col_dims) - : static_cast(*x); - auto y_matrix = y->dims().size() > 2 - ? framework::ReshapeToMatrix(*y, y_num_col_dims) - : static_cast(*y); - auto* dout = ctx.Input(framework::GradVarName("Out")); - - Tensor dout_mat; - dout_mat.ShareDataWith(*dout); - dout_mat.Resize({phi::flatten_to_2d(x->dims(), x_num_col_dims)[0], - phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]}); - - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - if (dx != nullptr) { - dx->set_lod(x->lod()); - } - if (dy != nullptr) { - dy->set_lod(y->lod()); - } - - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - if (dx) { - dx->mutable_data(ctx.GetPlace()); - Tensor dx_matrix = dx->dims().size() > 2 - ? framework::ReshapeToMatrix(*dx, x_num_col_dims) - : *dx; - - // dx = dout * y'. dx: M x K, dout : M x N, y : K x N - blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); - } - if (dy) { - dy->mutable_data(ctx.GetPlace()); - Tensor dy_matrix = dy->dims().size() > 2 - ? framework::ReshapeToMatrix(*dy, y_num_col_dims) - : *dy; - // dy = x' * dout. dy K x N, dout : M x N, x : M x K - blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); - } - } -}; - -template -class MulDoubleGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - int x_num_col_dims = ctx.template Attr("x_num_col_dims"); - int y_num_col_dims = ctx.template Attr("y_num_col_dims"); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto x_mat = x->dims().size() > 2 - ? framework::ReshapeToMatrix(*x, x_num_col_dims) - : static_cast(*x); - auto y_mat = y->dims().size() > 2 - ? framework::ReshapeToMatrix(*y, y_num_col_dims) - : static_cast(*y); - - const int m = phi::flatten_to_2d(x->dims(), x_num_col_dims)[0]; - const int n = phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]; - - auto* dout = ctx.Input("DOut"); - Tensor dout_mat; - dout_mat.ShareDataWith(*dout); - dout_mat.Resize({m, n}); - - auto* ddx = ctx.Input("DDX"); - auto* ddy = ctx.Input("DDY"); - - auto* dx = ctx.Output("DX"); - auto* dy = ctx.Output("DY"); - auto* ddout = ctx.Output("DDOut"); - - Tensor ddout_mat; - if (ddout) { - ddout->set_lod(dout->lod()); - // allocate and reshape ddout - ddout->mutable_data(ctx.GetPlace()); - ddout_mat.ShareDataWith(*ddout); - ddout_mat.Resize({m, n}); - } - - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - // a flag to specify whether ddout value has been set, if flag - // is false, MatMul beta should be 0 to set ddout, if flag is - // true, MatMul beta should be 1 to add result to ddout. - bool ddout_flag = false; - if (ddx) { - auto ddx_mat = ddx->dims().size() > 2 - ? framework::ReshapeToMatrix(*ddx, x_num_col_dims) - : static_cast(*ddx); - - // dy = ddx' * dout. dy : K x M, ddx' : K x M, dout : M x N - if (dy) { - dy->set_lod(y->lod()); - // allocate and reshape dy - dy->mutable_data(ctx.GetPlace()); - Tensor dy_mat = dy->dims().size() > 2 - ? framework::ReshapeToMatrix(*dy, y_num_col_dims) - : *dy; - blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat); - } - // ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N - if (ddout) { - blas.MatMul(ddx_mat, false, y_mat, false, static_cast(1.0), - &ddout_mat, static_cast(ddout_flag)); - ddout_flag = true; - } - } - if (ddy) { - auto ddy_mat = ddy->dims().size() > 2 - ? framework::ReshapeToMatrix(*ddy, y_num_col_dims) - : static_cast(*ddy); - // dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K - if (dx) { - dx->set_lod(x->lod()); - // allocate and reshape dx - dx->mutable_data(ctx.GetPlace()); - Tensor dx_mat = dx->dims().size() > 2 - ? framework::ReshapeToMatrix(*dx, x_num_col_dims) - : *dx; - blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat); - } - // ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N - if (ddout) { - blas.MatMul(x_mat, false, ddy_mat, false, static_cast(1.0), - &ddout_mat, static_cast(ddout_flag)); - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/mul_op_npu.cc b/paddle/fluid/operators/mul_op_npu.cc index e1fb5f4f6b..2aedfed9f8 100644 --- a/paddle/fluid/operators/mul_op_npu.cc +++ b/paddle/fluid/operators/mul_op_npu.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/mul_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/mul_op_xpu.cc b/paddle/fluid/operators/mul_op_xpu.cc index 1fdaa27299..6ef41e059c 100644 --- a/paddle/fluid/operators/mul_op_xpu.cc +++ b/paddle/fluid/operators/mul_op_xpu.cc @@ -14,11 +14,11 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/mul_op.h" #include #include #include #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { diff --git a/paddle/phi/kernels/cpu/matmul_grad_kernel.cc b/paddle/phi/kernels/cpu/matmul_grad_kernel.cc index c68e8115e8..aba519ff04 100644 --- a/paddle/phi/kernels/cpu/matmul_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/matmul_grad_kernel.cc @@ -45,3 +45,17 @@ PD_REGISTER_KERNEL(matmul_triple_grad, double, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(matmul_with_flatten_grad, + CPU, + ALL_LAYOUT, + phi::MatmulWithFlattenGradKernel, + float, + double) {} + +PD_REGISTER_KERNEL(matmul_with_flatten_double_grad, + CPU, + ALL_LAYOUT, + phi::MatmulWithFlattenDoubleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/matmul_kernel.cc b/paddle/phi/kernels/cpu/matmul_kernel.cc index 2bf56c07a5..8aa25c0da0 100644 --- a/paddle/phi/kernels/cpu/matmul_kernel.cc +++ b/paddle/phi/kernels/cpu/matmul_kernel.cc @@ -28,3 +28,10 @@ PD_REGISTER_KERNEL(matmul, double, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(matmul_with_flatten, + CPU, + ALL_LAYOUT, + phi::MatmulWithFlattenKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/gpu/matmul_grad_kernel.cu index ff23ebd05b..9c80d5e151 100644 --- a/paddle/phi/kernels/gpu/matmul_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/matmul_grad_kernel.cu @@ -49,3 +49,19 @@ PD_REGISTER_KERNEL(matmul_triple_grad, phi::dtype::float16, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(matmul_with_flatten_grad, + GPU, + ALL_LAYOUT, + phi::MatmulWithFlattenGradKernel, + float, + double, + phi::dtype::float16) {} + +PD_REGISTER_KERNEL(matmul_with_flatten_double_grad, + GPU, + ALL_LAYOUT, + phi::MatmulWithFlattenDoubleGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/matmul_kernel.cu b/paddle/phi/kernels/gpu/matmul_kernel.cu index 98be79c5f9..20c9a5229a 100644 --- a/paddle/phi/kernels/gpu/matmul_kernel.cu +++ b/paddle/phi/kernels/gpu/matmul_kernel.cu @@ -30,3 +30,11 @@ PD_REGISTER_KERNEL(matmul, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(matmul_with_flatten, + GPU, + ALL_LAYOUT, + phi::MatmulWithFlattenKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h index 495b93f2a4..25a9db868d 100644 --- a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h @@ -1731,4 +1731,163 @@ void MatmulTripleGradKernel(const Context& dev_ctx, } } +template +void MatmulWithFlattenGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto x_matrix = x.dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(x, x_num_col_dims) + : x; + auto y_matrix = y.dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(y, y_num_col_dims) + : y; + auto* dout = &out_grad; + + DenseTensor dout_mat(*dout); + dout_mat.Resize({phi::flatten_to_2d(x.dims(), x_num_col_dims)[0], + phi::flatten_to_2d(y.dims(), y_num_col_dims)[1]}); + + auto* dx = x_grad; + auto* dy = y_grad; + + if (dx != nullptr) { + dx->set_lod(x.lod()); + } + if (dy != nullptr) { + dy->set_lod(y.lod()); + } + + auto blas = phi::funcs::GetBlas(dev_ctx); + if (dx) { + dev_ctx.template Alloc(dx); + DenseTensor dx_matrix = + dx->dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(*dx, x_num_col_dims) + : *dx; + + // dx = dout * y'. dx: M x K, dout : M x N, y : K x N + blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); + } + if (dy) { + dev_ctx.template Alloc(dy); + DenseTensor dy_matrix = + dy->dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(*dy, y_num_col_dims) + : *dy; + // dy = x' * dout. dy K x N, dout : M x N, x : M x K + blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); + } +} + +template +void MatmulWithFlattenDoubleGradKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + paddle::optional x_grad_grad, + paddle::optional y_grad_grad, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* x_grad, + DenseTensor* y_grad, + DenseTensor* out_grad_grad) { + auto x_mat = x.dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(x, x_num_col_dims) + : x; + auto y_mat = y.dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(y, y_num_col_dims) + : y; + + const int m = phi::flatten_to_2d(x.dims(), x_num_col_dims)[0]; + const int n = phi::flatten_to_2d(y.dims(), y_num_col_dims)[1]; + + auto* dout = &out_grad; + DenseTensor dout_mat(*dout); + dout_mat.Resize({m, n}); + + auto* ddx = x_grad_grad.get_ptr(); + auto* ddy = y_grad_grad.get_ptr(); + + auto* dx = x_grad; + auto* dy = y_grad; + auto* ddout = out_grad_grad; + + DenseTensor ddout_mat; + if (ddout) { + ddout->set_lod(dout->lod()); + // allocate and reshape ddout + dev_ctx.template Alloc(ddout); + ddout_mat.ShareDataWith(*ddout); + ddout_mat.Resize({m, n}); + } + + auto blas = phi::funcs::GetBlas(dev_ctx); + // a flag to specify whether ddout value has been set, if flag + // is false, MatMul beta should be 0 to set ddout, if flag is + // true, MatMul beta should be 1 to add result to ddout. + bool ddout_flag = false; + if (ddx) { + auto ddx_mat = + ddx->dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(*ddx, x_num_col_dims) + : static_cast(*ddx); + + // dy = ddx' * dout. dy : K x M, ddx' : K x M, dout : M x N + if (dy) { + dy->set_lod(y.lod()); + // allocate and reshape dy + dev_ctx.template Alloc(dy); + DenseTensor dy_mat = + dy->dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(*dy, y_num_col_dims) + : *dy; + blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat); + } + // ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N + if (ddout) { + blas.MatMul(ddx_mat, + false, + y_mat, + false, + static_cast(1.0), + &ddout_mat, + static_cast(ddout_flag)); + ddout_flag = true; + } + } + if (ddy) { + auto ddy_mat = + ddy->dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(*ddy, y_num_col_dims) + : static_cast(*ddy); + // dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K + if (dx) { + dx->set_lod(x.lod()); + // allocate and reshape dx + dev_ctx.template Alloc(dx); + DenseTensor dx_mat = + dx->dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(*dx, x_num_col_dims) + : *dx; + blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat); + } + // ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N + if (ddout) { + blas.MatMul(x_mat, + false, + ddy_mat, + false, + static_cast(1.0), + &ddout_mat, + static_cast(ddout_flag)); + } + } +} + } // namespace phi diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index f6136de5d8..3201923e1b 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -506,4 +506,34 @@ void MatmulKernel(const Context& dev_ctx, MatMulFunction(dev_ctx, x, y, out, transpose_x, transpose_y); } +template +void MatmulWithFlattenKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out) { + const DenseTensor x_matrix = + x.dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(x, x_num_col_dims) + : x; + const DenseTensor y_matrix = + y.dims().size() > 2 + ? paddle::framework::ReshapeToMatrix(y, y_num_col_dims) + : y; + + dev_ctx.template Alloc(out); + auto z_dim = out->dims(); + if (z_dim.size() != 2) { + out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + + auto blas = phi::funcs::GetBlas(dev_ctx); + + blas.MatMul(x_matrix, y_matrix, out); + if (z_dim.size() != 2) { + out->Resize(z_dim); + } +} + } // namespace phi diff --git a/paddle/phi/kernels/matmul_grad_kernel.h b/paddle/phi/kernels/matmul_grad_kernel.h index 10452ff0b7..41a835db46 100644 --- a/paddle/phi/kernels/matmul_grad_kernel.h +++ b/paddle/phi/kernels/matmul_grad_kernel.h @@ -60,4 +60,28 @@ void MatmulTripleGradKernel(const Context& dev_ctx, DenseTensor* out_d_ddx, DenseTensor* out_d_ddy); +template +void MatmulWithFlattenGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* x_grad, + DenseTensor* y_grad); + +template +void MatmulWithFlattenDoubleGradKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + paddle::optional x_grad_grad, + paddle::optional y_grad_grad, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* x_grad, + DenseTensor* y_grad, + DenseTensor* out_grad_grad); + } // namespace phi diff --git a/paddle/phi/kernels/matmul_kernel.h b/paddle/phi/kernels/matmul_kernel.h index b524b9e586..a4c4971499 100644 --- a/paddle/phi/kernels/matmul_kernel.h +++ b/paddle/phi/kernels/matmul_kernel.h @@ -29,6 +29,16 @@ void MatmulKernel(const Context& dev_ctx, bool transpose_y, DenseTensor* out); +// In order to be compatible with `mul` op in fluid, +// it is no longer used in 2.x API +template +void MatmulWithFlattenKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out); + template DenseTensor Matmul(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/ops/compat/mul_sig.cc b/paddle/phi/ops/compat/mul_sig.cc new file mode 100644 index 0000000000..8770db1039 --- /dev/null +++ b/paddle/phi/ops/compat/mul_sig.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature MulGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("matmul_with_flatten_grad", + {"X", "Y", GradVarName("Out")}, + {"x_num_col_dims", "y_num_col_dims"}, + {GradVarName("X"), GradVarName("Y")}); +} + +KernelSignature MulDoubleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("matmul_with_flatten_double_grad", + {"X", "Y", "DOut", "DDX", "DDY"}, + {"x_num_col_dims", "y_num_col_dims"}, + {"DX", "DY", "DDOut"}); +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(mul, matmul_with_flatten); +PD_REGISTER_BASE_KERNEL_NAME(mul_grad, matmul_with_flatten_grad); +PD_REGISTER_BASE_KERNEL_NAME(mul_grad_grad, matmul_with_flatten_double_grad); + +PD_REGISTER_ARG_MAPPING_FN(mul_grad, phi::MulGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(mul_grad_grad, phi::MulDoubleGradOpArgumentMapping); -- GitLab