diff --git a/paddle/fluid/operators/math/blas.cc b/paddle/fluid/operators/math/blas.cc index 3e09ef7a249c789be7028bec193096086fd1296f..6a143b3c056455595fdedc131b0c5f4ee756e1e0 100644 --- a/paddle/fluid/operators/math/blas.cc +++ b/paddle/fluid/operators/math/blas.cc @@ -18,34 +18,26 @@ namespace paddle { namespace operators { namespace math { -MatDescriptor GetMatDim(const framework::DDim& dim, int num_flatten_cols, - bool trans) { +MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim, + int num_flatten_cols, bool trans) { + PADDLE_ENFORCE_GT(tensor_dim.size(), 1); MatDescriptor retv; if (num_flatten_cols > 1) { - auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols); + auto flatten_dim = framework::flatten_to_2d(tensor_dim, num_flatten_cols); retv.height_ = flatten_dim[0]; retv.width_ = flatten_dim[1]; } else { - if (dim.size() == 1) { - retv.height_ = 1; - retv.width_ = dim[0]; - } else if (dim.size() == 2) { - retv.height_ = dim[0]; - retv.width_ = dim[1]; + if (tensor_dim.size() == 2) { + retv.height_ = tensor_dim[0]; + retv.width_ = tensor_dim[1]; } else { - if (dim.size() == 3) { - retv.batch_size_ = dim[0]; - retv.height_ = dim[1]; - retv.width_ = dim[2]; - } else { - auto dim_vec = framework::vectorize(dim); - retv.batch_size_ = 1; - for (size_t i = 0; i < dim_vec.size() - 2; ++i) { - retv.batch_size_ *= dim_vec[i]; - retv.height_ = dim_vec[dim_vec.size() - 2]; - retv.width_ = dim_vec[dim_vec.size() - 1]; - } + auto dim_vec = framework::vectorize(tensor_dim); + retv.batch_size_ = 1; + for (size_t i = 0; i < dim_vec.size() - 2; ++i) { + retv.batch_size_ *= dim_vec[i]; } + retv.height_ = dim_vec[dim_vec.size() - 2]; + retv.width_ = dim_vec[dim_vec.size() - 1]; retv.stride_ = retv.height_ * retv.width_; } } diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 0c0794125a0f6d1a87f689315e882a43ac0feb22..dabde43850db770d286b13cacd32bee181328d5c 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -46,6 +46,25 @@ namespace paddle { namespace operators { namespace math { +/** + * Matrix Descriptor of a memory buffer. + * + * It is used for Blas::MatMul. MatMul operator can be batched. + * if Mat A is [BatchSize, H, W], Mat B is [BatchSize, H, W]. It will be a + * `batch_size` times of GEMM. The batched GEMM could be faster base on the + * implementation of the blas library. The batch size could be zero. If any + * matrix of `matmul` has a batch size, the will be a batched GEMM, too. e.g., + * Mat A is [BatchSize, H1, W2], and Mat B [H2, W2], The result matrix wil be + * [BatchSize, H1, W2] + * + * The boolean flag, `trans`, describe the memory is the transpose of matrix or + * not. If the trans is true, the last two dims of matrix are transposed. The + * memory layout of the matrix is [Width, Height] or [BatchSize, Width, Height]. + * + * The MatDescriptor is not only the dimension or shape of a matrix, it also + * contains the layout, stride of matrix. It is clearer to have a structure than + * reuse `DDim`. + */ struct MatDescriptor { int64_t height_; int64_t width_; @@ -54,8 +73,22 @@ struct MatDescriptor { bool trans_; }; -extern MatDescriptor GetMatDim(const framework::DDim& tensor, - int num_flatten_cols, bool trans); +/** + * Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose + * flag + * + * @param tensor_dim: The dimension of the tensor. The rank of this dimension + * must larger than 1. + * + * @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first + * dimension(column length) will be the product of tensor's first `num_col_dims` + * dimensions. If num_flatten_cols is zero, the first N-2 dimension will be the + * batch_size of descriptor. + * + * @param trans: True if the matrix is transposed. + */ +extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim, + int num_flatten_cols, bool trans); template class Blas { diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index c285d461e856191b9016fea53f10f694db99150e..da21b8ad7d4e353e1dbe98fde1fbac1b0d37fd5d 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -12,14 +12,257 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/matmul_op.h" #include +#include #include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { +/** + * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the + * original x_dim is returned. + */ +static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return framework::make_ddim({1, x_dim[0]}); +} + +/** + * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the + * original y_dim is returned. + */ +static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return framework::make_ddim({y_dim[0], 1}); +} + +template +class MatMulKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& x = + detail::Ref(context.Input("X"), "Cannot find X"); + auto& y = + detail::Ref(context.Input("Y"), "Cannot find Y"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + auto blas = math::GetBlas(context); + auto mat_dim_a = math::CreateMatrixDescriptor( + RowMatrixFromVector(x.dims()), 0, context.Attr("transpose_X")); + auto mat_dim_b = math::CreateMatrixDescriptor( + ColumnMatrixFromVector(y.dims()), 0, context.Attr("transpose_Y")); + blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0)); + } +}; + +// Reshape a rank-3 tensor from P x M x N to (P * M) x N. +// Identity op if the tensor is not of rank 3. +static framework::Tensor FoldInitDims(const framework::Tensor& input) { + auto output = input; + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); + } + return output; +} + +// Reshape a rank-3 tensor from P x M x N to M x (P * N). +// (Warning: This requires transposing data and writes into new memory.) +// Identity op if the tensor is not of rank 3. +template +static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, + const framework::Tensor& input) { + auto in_dims = input.dims(); + if (in_dims.size() != 3) { + return input; + } + framework::Tensor output; + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); + output.mutable_data(context.GetPlace()); + std::vector axis = {1, 0, 2}; + math::Transpose trans; + trans(context, input, &output, axis); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); -using framework::Tensor; + return output; +} + +/** + * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. + * + * The shape would be [BatchSize, H, W] or [H, W]. + * If transposed, `H,W` will be swapped. + */ +static void ReshapeTensorIntoMatrixSequence( + framework::Tensor* x, const math::MatDescriptor& descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + if (descriptor.batch_size_) { + x->Resize({descriptor.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} + +/** + * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor + * Out = matmul(x, y) + * + * This method will first calculate X,Y matrix sequence, and then calculate + * the out shape. + * + * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2] + * The out = [BatchSize, H1, W2] + * + * If there is no batch size in `X` and `Y`, the out will be [H1, W2] + * If any of `X` and `Y` has batch size BatchSize, the out will have the + * BatchSize. + */ +static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, + framework::Tensor* y, + framework::Tensor* out, bool trans_x, + bool trans_y) { + auto x_dim = RowMatrixFromVector(x->dims()); + auto y_dim = ColumnMatrixFromVector(y->dims()); + auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, mat_dim_y.width_}); + } + + ReshapeTensorIntoMatrixSequence(x, mat_dim_x); + ReshapeTensorIntoMatrixSequence(y, mat_dim_y); +} + +// Using dimensional constraints on matrix multiplication, it is +// straight-forward to check the following table for when X and Y +// are both matrices. +// +// transpose_X | False | True | False | True +// transpose_Y | False | False | True | True +// -----------+----------+----------+----------+----------- +// dX = | dOut Y^T | Y dOut^T | dOut Y | Y^T dOut^T +// dY = | X^T dOut | X dOut | dOut^T X | dOut^T X^T +// +// When X is a vector of size K, we treat it instead as a matrix of shape +// (1, K). Similarly, when Y is a vector of size K, we treat it instead as +// a matrix of shape (K, 1). +// +// When X and Y are both 3-dimensional tensors, then the first dimension +// the batch dimension can be ignored and the exact same formulas apply +// as for two matrices. +// +// Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end +// up with formulas like +// +// dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj} +// +// To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N +// to X: (P * M) x K, dOut: (P * M) x N. +template +class MatMulGradKernel : public framework::OpKernel { + public: + void MatMul(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + const framework::Tensor& b, bool trans_b, + framework::Tensor* out) const { + out->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0)); + } + + void CalcInputGrad(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + bool is_fold_init_dims_a, const framework::Tensor& b, + bool trans_b, bool is_fold_init_dims_b, + framework::Tensor* out) const { + if (out == nullptr) return; + bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && + out->dims().size() == 2; + if (!need_combine) { + MatMul(context, a, trans_a, b, trans_b, out); + } else { + auto& ctx = context.template device_context(); + MatMul(context, is_fold_init_dims_a + ? FoldInitDims(a) + : FoldHeadAndLastDims(ctx, a), + trans_a, is_fold_init_dims_b + ? FoldInitDims(b) + : FoldHeadAndLastDims(ctx, b), + trans_b, out); + } + } + + void Compute(const framework::ExecutionContext& context) const override { + auto x = *context.Input("X"); + auto y = *context.Input("Y"); + auto dout = + *context.Input(framework::GradVarName("Out")); + auto* dx = context.Output(framework::GradVarName("X")); + auto* dy = context.Output(framework::GradVarName("Y")); + bool transpose_x = context.Attr("transpose_X"); + bool transpose_y = context.Attr("transpose_Y"); + + ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + framework::DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); + } + } + + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); + } + } + + if (transpose_x && transpose_y) { + CalcInputGrad(context, y, true, true, dout, true, false, dx); + CalcInputGrad(context, dout, true, true, x, true, false, dy); + } else if (transpose_x) { + CalcInputGrad(context, y, false, false, dout, true, false, dx); + CalcInputGrad(context, x, false, false, dout, false, true, dy); + } else if (transpose_y) { + CalcInputGrad(context, dout, false, false, y, false, true, dx); + CalcInputGrad(context, dout, true, true, x, false, true, dy); + } else { + CalcInputGrad(context, dout, false, false, y, true, false, dx); + CalcInputGrad(context, x, true, true, dout, false, true, dy); + } + + if (dx) { + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); + } + } + if (dy) { + if (dy_dims != y.dims()) { + dy->Resize(dy_dims); + } + } + } +}; class MatMulOp : public framework::OperatorWithKernel { public: @@ -37,9 +280,11 @@ class MatMulOp : public framework::OperatorWithKernel { auto dim_x = context->GetInputDim("X"); auto dim_y = context->GetInputDim("Y"); - auto mat_dim_x = math::GetMatDim(GetXDim(dim_x), 0, + auto mat_dim_x = + math::CreateMatrixDescriptor(RowMatrixFromVector(dim_x), 0, context->Attrs().Get("transpose_X")); - auto mat_dim_y = math::GetMatDim(GetYDim(dim_y), 0, + auto mat_dim_y = + math::CreateMatrixDescriptor(ColumnMatrixFromVector(dim_y), 0, context->Attrs().Get("transpose_Y")); PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); @@ -151,15 +396,40 @@ class MatMulOpGrad : public framework::OperatorWithKernel { } }; +class MatMulOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* retv = new framework::OpDesc(); + retv->SetType("matmul_grad"); + retv->SetInput("X", Input("X")); + retv->SetInput("Y", Input("Y")); + retv->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), InputGrad("X")); + retv->SetOutput(framework::GradVarName("Y"), InputGrad("Y")); + retv->SetAttrMap(Attrs()); + return std::unique_ptr(retv); + } +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker, - paddle::framework::DefaultGradOpDescMaker); + ops::MatMulOpGradMaker); REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad); REGISTER_OP_CPU_KERNEL( matmul, ops::MatMulKernel); REGISTER_OP_CPU_KERNEL( matmul_grad, ops::MatMulGradKernel); + +#ifdef PADDLE_WITH_CUDA +REGISTER_OP_CUDA_KERNEL( + matmul, ops::MatMulKernel); +REGISTER_OP_CUDA_KERNEL( + matmul_grad, + ops::MatMulGradKernel); +#endif diff --git a/paddle/fluid/operators/matmul_op.cu.cc b/paddle/fluid/operators/matmul_op.cu.cc deleted file mode 100644 index e021bbe645399e410cde5c3ff7035d4d68c71744..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/matmul_op.cu.cc +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/matmul_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - matmul, ops::MatMulKernel); -REGISTER_OP_CUDA_KERNEL( - matmul_grad, - ops::MatMulGradKernel); diff --git a/paddle/fluid/operators/matmul_op.h b/paddle/fluid/operators/matmul_op.h deleted file mode 100644 index 9bf39026ff3f4202af141b7f9926c7a30e61d133..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/matmul_op.h +++ /dev/null @@ -1,242 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/detail/safe_ref.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/math_function.h" - -namespace paddle { -namespace operators { -inline framework::DDim GetXDim(const framework::DDim& x_dim) { - if (x_dim.size() > 1) { - return x_dim; - } - return framework::make_ddim({1, x_dim[0]}); -} - -inline framework::DDim GetYDim(const framework::DDim& y_dim) { - if (y_dim.size() > 1) { - return y_dim; - } - return framework::make_ddim({y_dim[0], 1}); -} - -template -class MatMulKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto& x = - detail::Ref(context.Input("X"), "Cannot find X"); - auto& y = - detail::Ref(context.Input("Y"), "Cannot find Y"); - auto* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - - auto blas = math::GetBlas(context); - auto mat_dim_a = math::GetMatDim(GetXDim(x.dims()), 0, - context.Attr("transpose_X")); - auto mat_dim_b = math::GetMatDim(GetYDim(y.dims()), 0, - context.Attr("transpose_Y")); - blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0)); - } -}; - -// Reshape a rank-3 tensor from P x M x N to (P * M) x N. -// Identity op if the tensor is not of rank 3. -inline framework::Tensor CombineBatchAndM(const framework::Tensor& input) { - auto output = input; - auto in_dims = input.dims(); - if (in_dims.size() == 3) { - output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); - } - return output; -} - -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. -template -inline framework::Tensor CombineBatchAndN(const DeviceContext& context, - const framework::Tensor& input) { - auto in_dims = input.dims(); - if (in_dims.size() != 3) { - return input; - } - framework::Tensor output; - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - output.mutable_data(context.GetPlace()); - std::vector axis = {1, 0, 2}; - math::Transpose trans; - trans(context, input, &output, axis); - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - - return output; -} - -inline void NormalizeTensorShape(framework::Tensor* x, - const math::MatDescriptor& mat_dim_x) { - int64_t h, w; - h = mat_dim_x.height_; - w = mat_dim_x.width_; - if (mat_dim_x.trans_) { - std::swap(w, h); - } - if (mat_dim_x.batch_size_) { - x->Resize({mat_dim_x.batch_size_, h, w}); - } else { - x->Resize({h, w}); - } -} - -inline void NormalizeXYOutTensorShape(framework::Tensor* x, - framework::Tensor* y, - framework::Tensor* out, bool trans_a, - bool trans_b) { - auto x_dim = GetXDim(x->dims()); - auto y_dim = GetYDim(y->dims()); - auto mat_dim_x = math::GetMatDim(x_dim, 0, trans_a); - auto mat_dim_y = math::GetMatDim(y_dim, 0, trans_b); - if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { - out->Resize({mat_dim_x.height_, mat_dim_y.width_}); - } else { - out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), - mat_dim_x.height_, mat_dim_y.width_}); - } - - NormalizeTensorShape(x, mat_dim_x); - NormalizeTensorShape(y, mat_dim_y); -} - -// Using dimensional constraints on matrix multiplication, it is -// straight-forward to check the following table for when X and Y -// are both matrices. -// -// transpose_X | False | True | False | True -// transpose_Y | False | False | True | True -// -----------+----------+----------+----------+----------- -// dX = | dOut Y^T | Y dOut^T | dOut Y | Y^T dOut^T -// dY = | X^T dOut | X dOut | dOut^T X | dOut^T X^T -// -// When X is a vector of size K, we treat it instead as a matrix of shape -// (1, K). Similarly, when Y is a vector of size K, we treat it instead as -// a matrix of shape (K, 1). -// -// When X and Y are both 3-dimensional tensors, then the first dimension -// the batch dimension can be ignored and the exact same formulas apply -// as for two matrices. -// -// Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end -// up with formulas like -// -// dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj} -// -// To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N -// to X: (P * M) x K, dOut: (P * M) x N. -template -class MatMulGradKernel : public framework::OpKernel { - public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, - framework::Tensor* out) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::GetMatDim(a.dims(), 0, trans_a); - auto mat_dim_b = math::GetMatDim(b.dims(), 0, trans_b); - blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0)); - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_combine_m_a, const framework::Tensor& b, - bool trans_b, bool is_combine_m_b, - framework::Tensor* out) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out); - } else { - auto& ctx = context.template device_context(); - MatMul( - context, is_combine_m_a ? CombineBatchAndM(a) - : CombineBatchAndN(ctx, a), - trans_a, is_combine_m_b ? CombineBatchAndM(b) - : CombineBatchAndN(ctx, b), - trans_b, out); - } - } - - void Compute(const framework::ExecutionContext& context) const override { - auto x = *context.Input("X"); - auto y = *context.Input("Y"); - auto dout = - *context.Input(framework::GradVarName("Out")); - auto* dx = context.Output(framework::GradVarName("X")); - auto* dy = context.Output(framework::GradVarName("Y")); - bool transpose_x = context.Attr("transpose_X"); - bool transpose_y = context.Attr("transpose_Y"); - - NormalizeXYOutTensorShape(&x, &y, &dout, transpose_x, transpose_y); - framework::DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } - } - - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - } - - if (transpose_x && transpose_y) { - CalcInputGrad(context, y, true, true, dout, true, false, dx); - CalcInputGrad(context, dout, true, true, x, true, false, dy); - } else if (transpose_x && !transpose_y) { - CalcInputGrad(context, y, false, false, dout, true, false, dx); - CalcInputGrad(context, x, false, false, dout, false, true, dy); - } else if (!transpose_x && transpose_y) { - CalcInputGrad(context, dout, false, false, y, false, true, dx); - CalcInputGrad(context, dout, true, true, x, false, true, dy); - } else { - CalcInputGrad(context, dout, false, false, y, true, false, dx); - CalcInputGrad(context, x, true, true, dout, false, true, dy); - } - - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } - } - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } - } - } -}; - -} // namespace operators -} // namespace paddle