diff --git a/paddle/fluid/operators/bmm_op.cc b/paddle/fluid/operators/bmm_op.cc
index 8cacc3c4f2277ae8a40f9990479c652e7b08948a..305236134dbe18f8f318064800fe4a7b2c95796b 100644
--- a/paddle/fluid/operators/bmm_op.cc
+++ b/paddle/fluid/operators/bmm_op.cc
@@ -16,6 +16,11 @@
 #include <vector>
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"
+
 namespace paddle {
 namespace operators {
 
@@ -24,62 +29,6 @@ class BmmOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("X"),
-        true,
-        platform::errors::NotFound("Input(X) of BmmOp should not be null"));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Y"),
-        true,
-        platform::errors::NotFound("Input(Y) of BmmOp should not be null"));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("Out"),
-        true,
-        platform::errors::NotFound("Output(Out) of BmmOp should not be null."));
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    PADDLE_ENFORCE_EQ(x_dims.size(),
-                      3,
-                      platform::errors::InvalidArgument(
-                          "Input(X) of BmmOp must be 3-dimensional in BmmOp, "
-                          "but received X's shape: [%s].",
-                          x_dims));
-    PADDLE_ENFORCE_EQ(y_dims.size(),
-                      3,
-                      platform::errors::InvalidArgument(
-                          "Input(Y) of BmmOp must be 3-dimensional in BmmOp, "
-                          "but received Y's shape: [%s].",
-                          y_dims));
-    PADDLE_ENFORCE_EQ(
-        x_dims[0],
-        y_dims[0],
-        platform::errors::InvalidArgument(
-            "Input(X) and Input(Y) must have the same batch size in BmmOp, "
-            "but received X's batch size: [%s],"
-            "Y's batch size [%s]",
-            x_dims[0],
-            y_dims[0]));
-    PADDLE_ENFORCE_EQ(
-        x_dims[2],
-        y_dims[1],
-        platform::errors::InvalidArgument(
-            "Input(X)'s width must be equal with Input(Y)'s height in BmmOp,"
-            "but receive X's width: [%s],"
-            "Y's height: [%s].",
-            x_dims[2],
-            y_dims[1]));
-
-    std::vector<int64_t> dim_out;
-    dim_out.push_back(x_dims[0]);
-    dim_out.push_back(x_dims[1]);
-    dim_out.push_back(y_dims[2]);
-    ctx->SetOutputDim("Out", phi::make_ddim(dim_out));
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
@@ -110,33 +59,6 @@ class BmmOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("X"),
-        true,
-        platform::errors::NotFound("Input(X) of BmmOp should not be null"));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Y"),
-        true,
-        platform::errors::NotFound("Input(Y) of BmmOp should not be null"));
-    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")),
-                      true,
-                      platform::errors::NotFound(
-                          "Output(Out@GRAD) of BmmOp should not be null."));
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    auto x_grad_name = framework::GradVarName("X");
-    auto y_grad_name = framework::GradVarName("Y");
-
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-    if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
-    }
-  }
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
@@ -166,15 +88,16 @@ class BmmOpGradMaker : public framework::SingleGradOpMaker<T> {
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(bmm,
+                            BmmInferShapeFunctor,
+                            PD_INFER_META(phi::BmmInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(bmm_grad,
+                            BmmGradInferShapeFunctor,
+                            PD_INFER_META(phi::BmmGradInferMeta));
 REGISTER_OPERATOR(bmm,
                   ops::BmmOp,
                   ops::BmmOpMaker,
                   ops::BmmOpGradMaker<paddle::framework::OpDesc>,
-                  ops::BmmOpGradMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(bmm_grad, ops::BmmOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    bmm,
-    ops::BmmKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::BmmKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    bmm_grad,
-    ops::BmmGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::BmmGradKernel<paddle::platform::CPUDeviceContext, double>);
+                  ops::BmmOpGradMaker<paddle::imperative::OpBase>,
+                  BmmInferShapeFunctor);
+REGISTER_OPERATOR(bmm_grad, ops::BmmOpGrad, BmmGradInferShapeFunctor);
diff --git a/paddle/fluid/operators/bmm_op.cu b/paddle/fluid/operators/bmm_op.cu
deleted file mode 100644
index c3e03299daf4881a35ab13c2138f0fb34a89d1a2..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/bmm_op.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-   http://www.apache.org/licenses/LICENSE-2.0
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/fluid/operators/bmm_op.h"
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    bmm,
-    ops::BmmKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::BmmKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::BmmKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::float16>);
-
-REGISTER_OP_CUDA_KERNEL(
-    bmm_grad,
-    ops::BmmGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::BmmGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::BmmGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::float16>);
-#endif
diff --git a/paddle/fluid/operators/bmm_op.h b/paddle/fluid/operators/bmm_op.h
index f8ebdcd4d8f6df667feb391f6beeef3f01c3906f..110cd2d2810d8091bf62c42293aba3575c6fafb9 100644
--- a/paddle/fluid/operators/bmm_op.h
+++ b/paddle/fluid/operators/bmm_op.h
@@ -58,95 +58,6 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
   ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
 }
 
-template <typename DeviceContext, typename T>
-class BmmKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    const Tensor &x = *context.Input<Tensor>("X");
-    const Tensor &y = *context.Input<Tensor>("Y");
-    Tensor *out = context.Output<Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-
-    if (x.numel() == 0 || y.numel() == 0) {
-      return;
-    }
-
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
-
-    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(x.dims(), 0, false);
-    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(y.dims(), 0, false);
-
-    // auto scale = static_cast<T>(context.Attr<float>("alpha"));
-    blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
-  }
-};
-
-template <typename DeviceContext, typename T>
-class BmmGradKernel : public framework::OpKernel<T> {
- public:
-  void MatMul(const framework::ExecutionContext &context,
-              const framework::Tensor &a,
-              bool trans_a,
-              const framework::Tensor &b,
-              bool trans_b,
-              framework::Tensor *out) const {
-    out->mutable_data<T>(context.GetPlace());
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
-    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
-    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
-
-    blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
-  }
-  void CalcInputGrad(const framework::ExecutionContext &context,
-                     const framework::Tensor &a,
-                     bool trans_a,
-                     const framework::Tensor &b,
-                     bool trans_b,
-                     framework::Tensor *out) const {
-    if (out == nullptr) return;
-    MatMul(context, a, trans_a, b, trans_b, out);
-  }
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto x = *context.Input<framework::Tensor>("X");
-    auto y = *context.Input<framework::Tensor>("Y");
-    auto dout =
-        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto *dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
-
-    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, false, false);
-    framework::DDim dx_dims;
-    if (dx) {
-      dx_dims = dx->dims();
-      if (dx_dims != x.dims()) {
-        dx->Resize(x.dims());
-      }
-    }
-
-    framework::DDim dy_dims;
-    if (dy) {
-      dy_dims = dy->dims();
-      if (dy_dims != y.dims()) {
-        dy->Resize(y.dims());
-      }
-    }
-
-    CalcInputGrad(context, dout, false, y, true, dx);
-    CalcInputGrad(context, x, true, dout, false, dy);
-
-    if (dx) {
-      if (dx_dims != x.dims()) {
-        dx->Resize(dx_dims);
-      }
-    }
-    if (dy) {
-      if (dy_dims != y.dims()) {
-        dy->Resize(dy_dims);
-      }
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 #endif  // PADDLE_FLUID_OPERATORS_BMM_OP_H_
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 3480af8db88d3ab116e1c64ea13c05e9eecd6be8..1eca092a5f22f0fdc15b97eaadb288e456bfdda3 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -73,6 +73,17 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
   }
 }
 
+void BmmGradInferMeta(const MetaTensor& x,
+                      const MetaTensor& y,
+                      const MetaTensor& out_grad,
+                      MetaTensor* x_grad,
+                      MetaTensor* y_grad) {
+  x_grad->set_dims(x.dims());
+  y_grad->set_dims(y.dims());
+  x_grad->set_dtype(x.dtype());
+  y_grad->set_dtype(y.dtype());
+}
+
 void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                  int groups,
                                  const std::string& data_format,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 88825faa95f7c323de1057192f4ca788cd15fec4..5551b6bcbf183b789c2bea7698e66cce1f933bfe 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -41,6 +41,12 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
                                         MetaTensor* dweight,
                                         MetaTensor* dbias);
 
+void BmmGradInferMeta(const MetaTensor& x,
+                      const MetaTensor& y,
+                      const MetaTensor& out_grad,
+                      MetaTensor* x_grad,
+                      MetaTensor* y_grad);
+
 void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                  int groups,
                                  const std::string& data_format,
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 1bbcd52e8bee4b1cce806f228358e767f627e06f..1463296664b461be1fd92544604e66437555a6aa 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -260,6 +260,53 @@ void BincountInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
+  std::vector<int64_t> x_dims = phi::vectorize(x.dims());
+  std::vector<int64_t> y_dims = phi::vectorize(y.dims());
+  std::size_t x_ndims = x_dims.size();
+  std::size_t y_ndims = y_dims.size();
+
+  PADDLE_ENFORCE_EQ(x_ndims,
+                    3,
+                    phi::errors::InvalidArgument(
+                        "Input(X) of BmmOp must be 3-dimensional in BmmOp, "
+                        "but received X's shape: [%s].",
+                        x_ndims));
+  PADDLE_ENFORCE_EQ(y_ndims,
+                    3,
+                    phi::errors::InvalidArgument(
+                        "Input(Y) of BmmOp must be 3-dimensional in BmmOp, "
+                        "but received Y's shape: [%s].",
+                        y_ndims));
+  PADDLE_ENFORCE_EQ(
+      x_dims[0],
+      y_dims[0],
+      phi::errors::InvalidArgument(
+          "Input(X) and Input(Y) must have the same batch size in BmmOp, "
+          "but received X's batch size: [%s],"
+          "Y's batch size [%s]",
+          x_dims[0],
+          y_dims[0]));
+  PADDLE_ENFORCE_EQ(
+      x_dims[2],
+      y_dims[1],
+      phi::errors::InvalidArgument(
+          "Input(X)'s width must be equal with Input(Y)'s height in BmmOp,"
+          "but receive X's width: [%s],"
+          "Y's height: [%s].",
+          x_dims[2],
+          y_dims[1]));
+
+  std::vector<int64_t> dim_out;
+  dim_out.push_back(x_dims[0]);
+  dim_out.push_back(x_dims[1]);
+  dim_out.push_back(y_dims[2]);
+  out->set_dims(phi::make_ddim(dim_out));
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}
+
 void CholeskySolveInferMeta(const MetaTensor& x,
                             const MetaTensor& y,
                             bool upper,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 70dafe24fbd731d6a3157f497caf254beddebbbf..85851ee705d204a71afcff5aae1a13ce16921e29 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -60,6 +60,8 @@ void BincountInferMeta(const MetaTensor& x,
                        int minlength,
                        MetaTensor* out);
 
+void BmmInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
+
 void CholeskySolveInferMeta(const MetaTensor& x,
                             const MetaTensor& y,
                             bool upper,
diff --git a/paddle/phi/kernels/bmm_grad_kernel.h b/paddle/phi/kernels/bmm_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..1b56641d113c5767a007c22aa48638d55dfa8b70
--- /dev/null
+++ b/paddle/phi/kernels/bmm_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BmmGradKernel(const Context& ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   const DenseTensor& out_grad,
+                   DenseTensor* x_grad,
+                   DenseTensor* y_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/bmm_kernel.h b/paddle/phi/kernels/bmm_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..09e7f9647b68ebd7209c205881d0b6e599c685ae
--- /dev/null
+++ b/paddle/phi/kernels/bmm_kernel.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+/**
+ * @brief Bmm Kernel.
+ *        Applies batched matrix multiplication to two tensors.
+ *
+ * Both of the two input tensors must be three-dimensional
+ * and share the same batch size.
+ * If x is a (b, m, k) tensor, y is a (b, k, n) tensor,
+ * the output will be a (b, m, n) tensor.
+ *
+ * @param ctx device context
+ * @param x The input tensor
+ * @param y The input tensor
+ * @param out The product Tensor
+ */
+template <typename T, typename Context>
+void BmmKernel(const Context& ctx,
+               const DenseTensor& x,
+               const DenseTensor& y,
+               DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/bmm_grad_kernel.cc b/paddle/phi/kernels/cpu/bmm_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b8d305666626b48bc51203e91a7eb0be5dcbfbaf
--- /dev/null
+++ b/paddle/phi/kernels/cpu/bmm_grad_kernel.cc
@@ -0,0 +1,22 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/bmm_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/bmm_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(
+    bmm_grad, CPU, ALL_LAYOUT, phi::BmmGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/bmm_kernel.cc b/paddle/phi/kernels/cpu/bmm_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7bc8930f3ff27d412d581aba7fa71df74617ed86
--- /dev/null
+++ b/paddle/phi/kernels/cpu/bmm_kernel.cc
@@ -0,0 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/bmm_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/bmm_kernel_impl.h"
+
+PD_REGISTER_KERNEL(bmm, CPU, ALL_LAYOUT, phi::BmmKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/bmm_grad_kernel.cu b/paddle/phi/kernels/gpu/bmm_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1e51d373e097a0bb5ede62a1ef51eed6951c41a6
--- /dev/null
+++ b/paddle/phi/kernels/gpu/bmm_grad_kernel.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/bmm_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/bmm_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(bmm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BmmGradKernel,
+                   float,
+                   double,
+                   paddle::platform::float16) {}
diff --git a/paddle/phi/kernels/gpu/bmm_kernel.cu b/paddle/phi/kernels/gpu/bmm_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..36dfad3d6a037cfd17d16979e343ae409d927007
--- /dev/null
+++ b/paddle/phi/kernels/gpu/bmm_kernel.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/bmm_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/bmm_kernel_impl.h"
+
+PD_REGISTER_KERNEL(bmm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BmmKernel,
+                   float,
+                   double,
+                   paddle::platform::float16) {}
diff --git a/paddle/phi/kernels/impl/bmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/bmm_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..33ca7233a37268d9bed2e57556fef9f4ca8e0b0a
--- /dev/null
+++ b/paddle/phi/kernels/impl/bmm_grad_kernel_impl.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/bmm_grad_kernel.h"
+
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MatMul(const Context& dev_ctx,
+            const DenseTensor& a,
+            bool trans_a,
+            const DenseTensor& b,
+            bool trans_b,
+            DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
+  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
+
+  blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
+}
+
+template <typename T, typename Context>
+void CalcInputGrad(const Context& dev_ctx,
+                   const DenseTensor& a,
+                   bool trans_a,
+                   const DenseTensor& b,
+                   bool trans_b,
+                   DenseTensor* out) {
+  if (out == nullptr) return;
+  MatMul<T, Context>(dev_ctx, a, trans_a, b, trans_b, out);
+}
+
+template <typename T, typename Context>
+void BmmGradKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   const DenseTensor& out_grad,
+                   DenseTensor* x_grad,
+                   DenseTensor* y_grad) {
+  DenseTensor x_help = x;
+  DenseTensor y_help = y;
+  DenseTensor out_grad_help = out_grad;
+  ReshapeXYOutIntoMatrixSequence(
+      &x_help, &y_help, &out_grad_help, false, false);
+
+  phi::DDim dx_dims;
+  if (x_grad) {
+    dx_dims = x_grad->dims();
+    if (dx_dims != x_help.dims()) {
+      x_grad->Resize(x_help.dims());
+    }
+  }
+
+  phi::DDim dy_dims;
+  if (y_grad) {
+    dy_dims = y_grad->dims();
+    if (dy_dims != y_help.dims()) {
+      y_grad->Resize(y_help.dims());
+    }
+  }
+
+  CalcInputGrad<T, Context>(
+      dev_ctx, out_grad_help, false, y_help, true, x_grad);
+  CalcInputGrad<T, Context>(
+      dev_ctx, x_help, true, out_grad_help, false, y_grad);
+
+  if (x_grad) {
+    if (dx_dims != x_help.dims()) {
+      x_grad->Resize(dx_dims);
+    }
+  }
+  if (y_grad) {
+    if (dy_dims != y_help.dims()) {
+      y_grad->Resize(dy_dims);
+    }
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/bmm_kernel_impl.h b/paddle/phi/kernels/impl/bmm_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..bff5b730deca908970a06e672a27e77f2a6d2a72
--- /dev/null
+++ b/paddle/phi/kernels/impl/bmm_kernel_impl.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/bmm_kernel.h"
+
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BmmKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const DenseTensor& y,
+               DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+
+  if (x.numel() == 0 || y.numel() == 0) {
+    return;
+  }
+
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+
+  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(x.dims(), 0, false);
+  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(y.dims(), 0, false);
+
+  blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
+}
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/bmm_sig.cc b/paddle/phi/ops/compat/bmm_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..415a90c3d3b3f8e6b55eadca35ba260bb062ca59
--- /dev/null
+++ b/paddle/phi/ops/compat/bmm_sig.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature BmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "bmm_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(bmm_grad, phi::BmmGradOpArgumentMapping);
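
Note for reviewers: `BmmKernel` delegates to `blas.MatMul` with batched matrix descriptors, so the contract it must preserve is plain batched matrix multiplication. The following is a minimal, framework-free C++ sketch of that contract, illustrative only and not part of the patch; the helper name `BmmReference` and the flat row-major buffer layout are assumptions made for the example:

```cpp
#include <cstddef>
#include <vector>

// Reference semantics of bmm: for x of shape (b, m, k) and y of shape
// (b, k, n), both stored as flat row-major buffers, compute out of shape
// (b, m, n) as an independent m-by-k times k-by-n matrix product per
// batch entry.
std::vector<float> BmmReference(const std::vector<float>& x,
                                const std::vector<float>& y,
                                std::size_t b,
                                std::size_t m,
                                std::size_t k,
                                std::size_t n) {
  std::vector<float> out(b * m * n, 0.0f);
  for (std::size_t i = 0; i < b; ++i) {      // batch entry
    for (std::size_t r = 0; r < m; ++r) {    // row of x[i]
      for (std::size_t c = 0; c < n; ++c) {  // column of y[i]
        float acc = 0.0f;
        for (std::size_t t = 0; t < k; ++t) {
          acc += x[(i * m + r) * k + t] * y[(i * k + t) * n + c];
        }
        out[(i * m + r) * n + c] = acc;
      }
    }
  }
  return out;
}
```

The backward kernel follows from the same contract: with out = x @ y per batch, the gradients are x_grad = out_grad @ y^T and y_grad = x^T @ out_grad, which is what the two `CalcInputGrad<T, Context>` calls in `bmm_grad_kernel_impl.h` encode through their `trans_a`/`trans_b` flags.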