From de8f27485ac54cbfe1bb3bd59d2cc68c2e20c335 Mon Sep 17 00:00:00 2001 From: From00 Date: Sat, 26 Feb 2022 11:02:47 +0800 Subject: [PATCH] Move BilinearTensorProduct OP to phi (#39903) * Move BilinearTensorProduct OP to phi * Set dtype for Infermeta --- .../operators/bilinear_tensor_product_op.cc | 154 ++------------- .../operators/bilinear_tensor_product_op.cu | 29 --- .../operators/bilinear_tensor_product_op.h | 181 ------------------ paddle/phi/infermeta/backward.cc | 48 +++++ paddle/phi/infermeta/backward.h | 9 + paddle/phi/infermeta/multiary.cc | 66 +++++++ paddle/phi/infermeta/multiary.h | 7 + .../bilinear_tensor_product_grad_kernel.h | 32 ++++ .../kernels/bilinear_tensor_product_kernel.h | 30 +++ .../bilinear_tensor_product_grad_kernel.cc | 25 +++ .../cpu/bilinear_tensor_product_kernel.cc | 25 +++ .../bilinear_tensor_product_grad_kernel.cu | 25 +++ .../gpu/bilinear_tensor_product_kernel.cu | 25 +++ ...bilinear_tensor_product_grad_kernel_impl.h | 144 ++++++++++++++ .../bilinear_tensor_product_kernel_impl.h | 75 ++++++++ .../ops/compat/bilinear_tensor_product_sig.cc | 41 ++++ 16 files changed, 569 insertions(+), 347 deletions(-) delete mode 100644 paddle/fluid/operators/bilinear_tensor_product_op.cu delete mode 100644 paddle/fluid/operators/bilinear_tensor_product_op.h create mode 100644 paddle/phi/kernels/bilinear_tensor_product_grad_kernel.h create mode 100644 paddle/phi/kernels/bilinear_tensor_product_kernel.h create mode 100644 paddle/phi/kernels/cpu/bilinear_tensor_product_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/bilinear_tensor_product_kernel.cc create mode 100644 paddle/phi/kernels/gpu/bilinear_tensor_product_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/bilinear_tensor_product_kernel.cu create mode 100644 paddle/phi/kernels/impl/bilinear_tensor_product_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/bilinear_tensor_product_kernel_impl.h create mode 100644 paddle/phi/ops/compat/bilinear_tensor_product_sig.cc diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.cc b/paddle/fluid/operators/bilinear_tensor_product_op.cc index 253a96004bd..4774c0a1dbc 100644 --- a/paddle/fluid/operators/bilinear_tensor_product_op.cc +++ b/paddle/fluid/operators/bilinear_tensor_product_op.cc @@ -12,84 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/bilinear_tensor_product_op.h" -#include -#include -#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { -using framework::Tensor; - class BilinearTensorProductOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::InvalidArgument("Input(X) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Y"), true, - platform::errors::InvalidArgument("Input(Y) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Weight"), true, - platform::errors::InvalidArgument("Input(Weight) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument("Output(Out) should not be null.")); - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - auto weight_dims = ctx->GetInputDim("Weight"); - - PADDLE_ENFORCE_EQ( - x_dims.size(), 2UL, - platform::errors::InvalidArgument("The input(X) must be a 2D Tensor.")); - PADDLE_ENFORCE_EQ( - y_dims.size(), 2UL, - platform::errors::InvalidArgument("The input(Y) must be a 2D Tensor.")); - PADDLE_ENFORCE_EQ( - weight_dims.size(), 3UL, - platform::errors::InvalidArgument("Expected the input(Weight) is a 3D " - "tensor. But received %dD tensor.", - weight_dims.size())); - if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0)) { - PADDLE_ENFORCE_EQ( - x_dims[0], y_dims[0], - platform::errors::InvalidArgument( - "The first dimension(batch_size) of input(X) must be " - "equal to the first dimension of the input(Y).")); - } - PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], - platform::errors::InvalidArgument( - "The second dimension of input(X) must be equal to " - "the second dimension of the input(Weight).")); - PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], - platform::errors::InvalidArgument( - "The second dimension of input(Y) must be equal to " - "the third dimension of the input(Weight).")); - - if (ctx->HasInput("Bias")) { - auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims.size(), 2UL, - platform::errors::InvalidArgument( - "The Input(Bias) must be a 2-D tensor with " - "the 2nd dimension fixed to 1 (a row vector).")); - PADDLE_ENFORCE_EQ(bias_dims[0], 1UL, - platform::errors::InvalidArgument( - "The Input(Bias) must be a 2-D tensor with " - "the 2nd dimension fixed to 1 (a row vector).")); - PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], - platform::errors::InvalidArgument( - "The second dimension of input(Bias) must be equal " - "to the first dimension of the input(Weight).")); - } - - ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { @@ -125,59 +59,6 @@ Where $W_i$ is the $i$-th slice of Input(Weight); class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::InvalidArgument("Input(X) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Y"), true, - platform::errors::InvalidArgument("Input(Y) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Weight"), true, - platform::errors::InvalidArgument("Input(Weight) should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, - platform::errors::InvalidArgument( - "Input(Out@GRAD) should not be null.")); - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - auto weight_dims = ctx->GetInputDim("Weight"); - auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - - PADDLE_ENFORCE_EQ(out_dims.size(), 2UL, - platform::errors::InvalidArgument( - "The input(Out@GRAD) must be a 2D Tensor.")); - PADDLE_ENFORCE_EQ( - x_dims[0], out_dims[0], - platform::errors::InvalidArgument( - "The first dimension(batch_size) of input(Out@GRAD) must be " - "equal to the first dimension of the Input(X).")); - PADDLE_ENFORCE_EQ( - weight_dims[0], out_dims[1], - platform::errors::InvalidArgument( - "The second dimension of input(Out@GRAD) must be equal to " - "the third dimension of the Input(Weight).")); - - auto bias_grad_name = framework::GradVarName("Bias"); - if (ctx->HasOutput(bias_grad_name)) { - ctx->SetOutputDim(bias_grad_name, {1, out_dims[1]}); - } - - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - auto weight_grad_name = framework::GradVarName("Weight"); - - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - if (ctx->HasOutput(y_grad_name)) { - ctx->SetOutputDim(y_grad_name, y_dims); - } - if (ctx->HasOutput(weight_grad_name)) { - ctx->SetOutputDim(weight_grad_name, weight_dims); - } - } }; template @@ -208,21 +89,20 @@ class BilinearTensorProductGradOpMaker } // namespace paddle namespace ops = paddle::operators; + +DELCARE_INFER_SHAPE_FUNCTOR(bilinear_tensor_product, + BilinearTensorProductInferShapeFunctor, + PT_INFER_META(phi::BilinearTensorProductInferMeta)); +DELCARE_INFER_SHAPE_FUNCTOR( + bilinear_tensor_product_grad, BilinearTensorProductGradInferShapeFunctor, + PT_INFER_META(phi::BilinearTensorProductGradInferMeta)); + REGISTER_OPERATOR( bilinear_tensor_product, ops::BilinearTensorProductOp, ops::BilinearTensorProductOpMaker, ops::BilinearTensorProductGradOpMaker, - ops::BilinearTensorProductGradOpMaker); + ops::BilinearTensorProductGradOpMaker, + BilinearTensorProductInferShapeFunctor); REGISTER_OPERATOR(bilinear_tensor_product_grad, - ops::BilinearTensorProductOpGrad); -REGISTER_OP_CPU_KERNEL( - bilinear_tensor_product, - ops::BilinearTensorProductKernel, - ops::BilinearTensorProductKernel); -REGISTER_OP_CPU_KERNEL( - bilinear_tensor_product_grad, - ops::BilinearTensorProductGradKernel, - ops::BilinearTensorProductGradKernel); + ops::BilinearTensorProductOpGrad, + BilinearTensorProductGradInferShapeFunctor); diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.cu b/paddle/fluid/operators/bilinear_tensor_product_op.cu deleted file mode 100644 index c2b4f69e685..00000000000 --- a/paddle/fluid/operators/bilinear_tensor_product_op.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/bilinear_tensor_product_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - bilinear_tensor_product, - ops::BilinearTensorProductKernel, - ops::BilinearTensorProductKernel); -REGISTER_OP_CUDA_KERNEL( - bilinear_tensor_product_grad, - ops::BilinearTensorProductGradKernel, - ops::BilinearTensorProductGradKernel); diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.h b/paddle/fluid/operators/bilinear_tensor_product_op.h deleted file mode 100644 index 2dbe3a132d7..00000000000 --- a/paddle/fluid/operators/bilinear_tensor_product_op.h +++ /dev/null @@ -1,181 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; - -template -using EigenMatrix = framework::EigenMatrix; - -template -class BilinearTensorProductKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* weight = ctx.Input("Weight"); - auto* bias = ctx.Input("Bias"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - auto y_mat = EigenMatrix::From(*y); - auto output_mat = EigenMatrix::From(*out); - - auto batch_size = x->dims()[0]; - auto weight_dims = weight->dims(); - int out_dim = weight_dims[0]; - auto x_dim = weight_dims[1]; - auto y_dim = weight_dims[2]; - auto& place = *ctx.template device_context().eigen_device(); - auto& dev_ctx = ctx.template device_context(); - - // Create the intermediate variable to calculate the result of - // Input(X) multiplied by Input(Weight_i), the formula is: - // left_mul = X Weight_i. - Tensor left_mul; - left_mul.mutable_data(phi::make_ddim({batch_size, y_dim}), - ctx.GetPlace()); - auto left_mul_mat = EigenMatrix::From(left_mul); - - for (int i = 0; i < out_dim; ++i) { - auto output_col_vec = output_mat.chip(i, 1); - Tensor weight_mat = - weight->Slice(i, i + 1).Resize(phi::make_ddim({x_dim, y_dim})); - phi::funcs::GetBlas(dev_ctx).GEMM( - CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data(), - weight_mat.data(), 0, left_mul.data()); - output_col_vec.device(place) = - (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); - } - if (bias) { - auto bias_vec = EigenMatrix::From(*bias); - Eigen::DSizes bcast(batch_size, 1); - output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; - } - } -}; - -template -class BilinearTensorProductGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* x = ctx.Input("X"); - const Tensor* y = ctx.Input("Y"); - const Tensor* weight = ctx.Input("Weight"); - Tensor* d_x = ctx.Output(framework::GradVarName("X")); - Tensor* d_y = ctx.Output(framework::GradVarName("Y")); - Tensor* d_weight = ctx.Output(framework::GradVarName("Weight")); - Tensor* d_bias = ctx.Output(framework::GradVarName("Bias")); - const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); - - auto batch_size = x->dims()[0]; - auto weight_dims = weight->dims(); - int out_dim = weight_dims[0]; - auto x_dim = weight_dims[1]; - auto y_dim = weight_dims[2]; - - auto x_mat = EigenMatrix::From(*x); - auto y_mat = EigenMatrix::From(*y); - auto d_out_mat = EigenMatrix::From(*d_out); - auto& place = *ctx.template device_context().eigen_device(); - auto& dev_ctx = ctx.template device_context(); - // Create the intermediate variable to calculate the Output(Y@Grad). - Tensor x_scale; - x_scale.mutable_data(phi::make_ddim({batch_size, x_dim}), - ctx.GetPlace()); - auto x_scale_mat = EigenMatrix::From(x_scale); - - // Create the intermediate variable to calculate the Output(X@Grad). - Tensor y_scale; - y_scale.mutable_data(phi::make_ddim({batch_size, y_dim}), - ctx.GetPlace()); - auto y_scale_mat = EigenMatrix::From(y_scale); - - phi::funcs::SetConstant set_zero; - - if (d_x) { - d_x->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, d_x, static_cast(0)); - } - - if (d_y) { - d_y->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, d_y, static_cast(0)); - } - - if (d_weight) { - d_weight->mutable_data(ctx.GetPlace()); - } - - auto blas = phi::funcs::GetBlas(ctx); - - // Caculate the Output(X@Grad) and Output(Y@Grad). - if (d_x || d_y || d_weight) { - Eigen::DSizes bcast_for_x(1, y_dim); - Eigen::DSizes bcast_for_y(1, x_dim); - Eigen::DSizes bcast_for_weight(1, x_dim); - - for (int i = 0; i < out_dim; ++i) { - Tensor weight_i = - weight->Slice(i, i + 1).Resize(phi::make_ddim({x_dim, y_dim})); - auto output_vec = d_out_mat.chip(i, 1); - - if (d_x) { - y_scale_mat.device(place) = - output_vec.reshape(Eigen::DSizes(batch_size, 1)) - .broadcast(bcast_for_x) * - y_mat; - blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1, - y_scale.data(), weight_i.data(), 1, d_x->data()); - } - - if (d_y || d_weight) { - auto output_vec_y = - output_vec.reshape(Eigen::DSizes(batch_size, 1)) - .broadcast(bcast_for_y); - x_scale_mat.device(place) = output_vec_y * x_mat; - if (d_y) { - blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, - x_scale.data(), weight_i.data(), 1, d_y->data()); - } - if (d_weight) { - Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( - phi::make_ddim({x_dim, y_dim})); - blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1, - x_scale.data(), y->data(), 0, d_weight_i.data()); - } - } - } - } - - // calculate the gradient of Input(Bias). - if (d_bias) { - d_bias->mutable_data(ctx.GetPlace()); - auto d_bias_mat = framework::EigenVector::Flatten(*d_bias); - d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes(0)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index c4ae2e0b371..e08eae0fc68 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -16,6 +16,54 @@ limitations under the License. */ namespace phi { +void BilinearTensorProductGradInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& weight, + const MetaTensor& dout, + MetaTensor* dx, + MetaTensor* dy, + MetaTensor* dweight, + MetaTensor* dbias) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + auto weight_dims = weight.dims(); + auto out_dims = dout.dims(); + + PADDLE_ENFORCE_EQ( + out_dims.size(), + 2UL, + errors::InvalidArgument("The input(Out@GRAD) must be a 2D Tensor.")); + PADDLE_ENFORCE_EQ( + x_dims[0], + out_dims[0], + errors::InvalidArgument( + "The first dimension(batch_size) of input(Out@GRAD) must be " + "equal to the first dimension of the Input(X).")); + PADDLE_ENFORCE_EQ( + weight_dims[0], + out_dims[1], + errors::InvalidArgument( + "The second dimension of input(Out@GRAD) must be equal to " + "the third dimension of the Input(Weight).")); + + if (dx) { + dx->set_dims(x_dims); + dx->set_dtype(x.dtype()); + } + if (dy) { + dy->set_dims(y_dims); + dy->set_dtype(y.dtype()); + } + if (dweight) { + dweight->set_dims(weight_dims); + dweight->set_dtype(weight.dtype()); + } + if (dbias) { + dbias->set_dims({1, out_dims[1]}); + dbias->set_dtype(dout.dtype()); + } +} + void GeneralBinaryGradInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* dx, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 965c380db25..35f988bbc0b 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -20,6 +20,15 @@ limitations under the License. */ namespace phi { +void BilinearTensorProductGradInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& weight, + const MetaTensor& dout, + MetaTensor* dx, + MetaTensor* dy, + MetaTensor* dweight, + MetaTensor* dbias); + void GeneralBinaryGradInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* dx, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index d72033f9528..7a0db3d5c17 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -18,6 +18,72 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/concat_funcs.h" namespace phi { +void BilinearTensorProductInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& weight, + paddle::optional bias, + MetaTensor* out, + MetaConfig config) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + auto weight_dims = weight.dims(); + + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2UL, + errors::InvalidArgument("The input(X) must be a 2D Tensor.")); + PADDLE_ENFORCE_EQ( + y_dims.size(), + 2UL, + errors::InvalidArgument("The input(Y) must be a 2D Tensor.")); + PADDLE_ENFORCE_EQ( + weight_dims.size(), + 3UL, + errors::InvalidArgument( + "Expected the input(Weight) is a 3D tensor. But received %dD tensor.", + weight_dims.size())); + if (config.is_runtime || (x_dims[0] > 0 && y_dims[0] > 0)) { + PADDLE_ENFORCE_EQ(x_dims[0], + y_dims[0], + errors::InvalidArgument( + "The first dimension(batch_size) of input(X) must be " + "equal to the first dimension of the input(Y).")); + } + PADDLE_ENFORCE_EQ(x_dims[1], + weight_dims[1], + errors::InvalidArgument( + "The second dimension of input(X) must be equal to " + "the second dimension of the input(Weight).")); + PADDLE_ENFORCE_EQ(y_dims[1], + weight_dims[2], + errors::InvalidArgument( + "The second dimension of input(Y) must be equal to " + "the third dimension of the input(Weight).")); + + if (bias.get_ptr()) { + auto bias_dims = bias->dims(); + PADDLE_ENFORCE_EQ(bias_dims.size(), + 2UL, + errors::InvalidArgument( + "The Input(Bias) must be a 2-D tensor with " + "the 2nd dimension fixed to 1 (a row vector).")); + PADDLE_ENFORCE_EQ(bias_dims[0], + 1UL, + errors::InvalidArgument( + "The Input(Bias) must be a 2-D tensor with " + "the 2nd dimension fixed to 1 (a row vector).")); + PADDLE_ENFORCE_EQ(bias_dims[1], + weight_dims[0], + errors::InvalidArgument( + "The second dimension of input(Bias) must be equal " + "to the first dimension of the input(Weight).")); + } + + out->set_dims({x_dims[0], weight_dims[0]}); + out->share_lod(x); + out->set_dtype(x.dtype()); +} + void ConcatInferMeta(const std::vector& x, const Scalar& axis_scalar, MetaTensor* out, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 589fc33333d..a5fb2a4cbdd 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -18,6 +18,13 @@ limitations under the License. */ #include "paddle/phi/core/meta_tensor.h" namespace phi { +void BilinearTensorProductInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& weight, + paddle::optional bias, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void ConcatInferMeta(const std::vector& x, const Scalar& axis_scalar, MetaTensor* out, diff --git a/paddle/phi/kernels/bilinear_tensor_product_grad_kernel.h b/paddle/phi/kernels/bilinear_tensor_product_grad_kernel.h new file mode 100644 index 00000000000..499aa1e0b2e --- /dev/null +++ b/paddle/phi/kernels/bilinear_tensor_product_grad_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void BilinearTensorProductGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& weight, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* dweight, + DenseTensor* dbias); + +} // namespace phi diff --git a/paddle/phi/kernels/bilinear_tensor_product_kernel.h b/paddle/phi/kernels/bilinear_tensor_product_kernel.h new file mode 100644 index 00000000000..b34e8946ddd --- /dev/null +++ b/paddle/phi/kernels/bilinear_tensor_product_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void BilinearTensorProductKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& weight, + paddle::optional bias, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/bilinear_tensor_product_grad_kernel.cc b/paddle/phi/kernels/cpu/bilinear_tensor_product_grad_kernel.cc new file mode 100644 index 00000000000..2268212316a --- /dev/null +++ b/paddle/phi/kernels/cpu/bilinear_tensor_product_grad_kernel.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bilinear_tensor_product_grad_kernel.h" +#include "paddle/phi/kernels/impl/bilinear_tensor_product_grad_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(bilinear_tensor_product_grad, + CPU, + ALL_LAYOUT, + phi::BilinearTensorProductGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/bilinear_tensor_product_kernel.cc b/paddle/phi/kernels/cpu/bilinear_tensor_product_kernel.cc new file mode 100644 index 00000000000..25bc5913865 --- /dev/null +++ b/paddle/phi/kernels/cpu/bilinear_tensor_product_kernel.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bilinear_tensor_product_kernel.h" +#include "paddle/phi/kernels/impl/bilinear_tensor_product_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(bilinear_tensor_product, + CPU, + ALL_LAYOUT, + phi::BilinearTensorProductKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/bilinear_tensor_product_grad_kernel.cu b/paddle/phi/kernels/gpu/bilinear_tensor_product_grad_kernel.cu new file mode 100644 index 00000000000..f4f69ee83ee --- /dev/null +++ b/paddle/phi/kernels/gpu/bilinear_tensor_product_grad_kernel.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bilinear_tensor_product_grad_kernel.h" +#include "paddle/phi/kernels/impl/bilinear_tensor_product_grad_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(bilinear_tensor_product_grad, + GPU, + ALL_LAYOUT, + phi::BilinearTensorProductGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/bilinear_tensor_product_kernel.cu b/paddle/phi/kernels/gpu/bilinear_tensor_product_kernel.cu new file mode 100644 index 00000000000..b81b842cedb --- /dev/null +++ b/paddle/phi/kernels/gpu/bilinear_tensor_product_kernel.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bilinear_tensor_product_kernel.h" +#include "paddle/phi/kernels/impl/bilinear_tensor_product_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(bilinear_tensor_product, + GPU, + ALL_LAYOUT, + phi::BilinearTensorProductKernel, + float, + double) {} diff --git a/paddle/phi/kernels/impl/bilinear_tensor_product_grad_kernel_impl.h b/paddle/phi/kernels/impl/bilinear_tensor_product_grad_kernel_impl.h new file mode 100644 index 00000000000..c199833b42a --- /dev/null +++ b/paddle/phi/kernels/impl/bilinear_tensor_product_grad_kernel_impl.h @@ -0,0 +1,144 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void BilinearTensorProductGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& weight, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* dweight, + DenseTensor* dbias) { + auto batch_size = x.dims()[0]; + auto weight_dims = weight.dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + + auto x_mat = EigenMatrix::From(x); + auto y_mat = EigenMatrix::From(y); + auto dout_mat = EigenMatrix::From(dout); + auto& place = *ctx.eigen_device(); + // Create the intermediate variable to calculate the Output(Y@Grad). + DenseTensor x_scale; + x_scale.Resize(make_ddim({batch_size, x_dim})); + ctx.template Alloc(&x_scale); + auto x_scale_mat = EigenMatrix::From(x_scale); + + // Create the intermediate variable to calculate the Output(X@Grad). + DenseTensor y_scale; + y_scale.Resize(make_ddim({batch_size, y_dim})); + ctx.template Alloc(&y_scale); + auto y_scale_mat = EigenMatrix::From(y_scale); + + funcs::SetConstant set_zero; + + if (dx) { + ctx.template Alloc(dx); + set_zero(ctx, dx, static_cast(0)); + } + + if (dy) { + ctx.template Alloc(dy); + set_zero(ctx, dy, static_cast(0)); + } + + if (dweight) { + ctx.template Alloc(dweight); + } + + auto blas = funcs::GetBlas(ctx); + + // Caculate the Output(X@Grad) and Output(Y@Grad). + if (dx || dy || dweight) { + Eigen::DSizes bcast_for_x(1, y_dim); + Eigen::DSizes bcast_for_y(1, x_dim); + Eigen::DSizes bcast_for_weight(1, x_dim); + + for (int i = 0; i < out_dim; ++i) { + DenseTensor weight_i = + weight.Slice(i, i + 1).Resize(make_ddim({x_dim, y_dim})); + auto output_vec = dout_mat.chip(i, 1); + + if (dx) { + y_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_x) * + y_mat; + blas.GEMM(CblasNoTrans, + CblasTrans, + batch_size, + x_dim, + y_dim, + 1, + y_scale.data(), + weight_i.data(), + 1, + dx->data()); + } + + if (dy || dweight) { + auto output_vec_y = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_y); + x_scale_mat.device(place) = output_vec_y * x_mat; + if (dy) { + blas.GEMM(CblasNoTrans, + CblasNoTrans, + batch_size, + y_dim, + x_dim, + 1, + x_scale.data(), + weight_i.data(), + 1, + dy->data()); + } + if (dweight) { + DenseTensor dweight_i = + dweight->Slice(i, i + 1).Resize(make_ddim({x_dim, y_dim})); + blas.GEMM(CblasTrans, + CblasNoTrans, + x_dim, + y_dim, + batch_size, + 1, + x_scale.data(), + y.data(), + 0, + dweight_i.data()); + } + } + } + } + + // calculate the gradient of Input(Bias). + if (dbias) { + ctx.template Alloc(dbias); + auto dbias_mat = EigenVector::Flatten(*dbias); + dbias_mat.device(place) = dout_mat.sum(Eigen::DSizes(0)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/bilinear_tensor_product_kernel_impl.h b/paddle/phi/kernels/impl/bilinear_tensor_product_kernel_impl.h new file mode 100644 index 00000000000..3f30a4b958e --- /dev/null +++ b/paddle/phi/kernels/impl/bilinear_tensor_product_kernel_impl.h @@ -0,0 +1,75 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void BilinearTensorProductKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& weight, + paddle::optional bias, + DenseTensor* out) { + ctx.template Alloc(out); + + auto y_mat = EigenMatrix::From(y); + auto output_mat = EigenMatrix::From(*out); + + auto batch_size = x.dims()[0]; + auto weight_dims = weight.dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + auto& place = *ctx.eigen_device(); + + // Create the intermediate variable to calculate the result of + // Input(X) multiplied by Input(Weight_i), the formula is: + // left_mul = X Weight_i. + DenseTensor left_mul; + left_mul.Resize(phi::make_ddim({batch_size, y_dim})); + ctx.template Alloc(&left_mul); + auto left_mul_mat = EigenMatrix::From(left_mul); + + for (int i = 0; i < out_dim; ++i) { + auto output_col_vec = output_mat.chip(i, 1); + DenseTensor weight_mat = + weight.Slice(i, i + 1).Resize(phi::make_ddim({x_dim, y_dim})); + phi::funcs::GetBlas(ctx).GEMM(CblasNoTrans, + CblasNoTrans, + batch_size, + y_dim, + x_dim, + 1, + x.data(), + weight_mat.data(), + 0, + left_mul.data()); + output_col_vec.device(place) = + (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + } + if (bias.get_ptr()) { + auto bias_vec = EigenMatrix::From(*(bias.get_ptr())); + Eigen::DSizes bcast(batch_size, 1); + output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; + } +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/bilinear_tensor_product_sig.cc b/paddle/phi/ops/compat/bilinear_tensor_product_sig.cc new file mode 100644 index 00000000000..570bf7ce943 --- /dev/null +++ b/paddle/phi/ops/compat/bilinear_tensor_product_sig.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature BilinearTensorProductOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}, {}, {"Out"}); +} + +KernelSignature BilinearTensorProductGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("bilinear_tensor_product_grad", + {"X", "Y", "Weight", GradVarName("Out")}, + {}, + {GradVarName("X"), + GradVarName("Y"), + GradVarName("Weight"), + GradVarName("Bias")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(bilinear_tensor_product, + phi::BilinearTensorProductOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(bilinear_tensor_product_grad, + phi::BilinearTensorProductGradOpArgumentMapping); -- GitLab