From c6a6d87f9693e4ce6296134a89a2ec2a7f596eda Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Mon, 7 May 2018 16:21:55 +0800
Subject: [PATCH] Rewrite Matmul, make code cleaner

---
 paddle/fluid/operators/math/blas.cc           |  39 ++-
 paddle/fluid/operators/math/blas.h            |  33 ++
 paddle/fluid/operators/math/matmul.h          | 149 ---------
 paddle/fluid/operators/matmul_op.cc           | 134 ++------
 paddle/fluid/operators/matmul_op.h            | 286 +++++++++---------
 .../fluid/tests/unittests/test_matmul_op.py   |  35 ++-
 6 files changed, 258 insertions(+), 418 deletions(-)
 delete mode 100644 paddle/fluid/operators/math/matmul.h

diff --git a/paddle/fluid/operators/math/blas.cc b/paddle/fluid/operators/math/blas.cc
index 3eeb77546b9..7427ceac60a 100644
--- a/paddle/fluid/operators/math/blas.cc
+++ b/paddle/fluid/operators/math/blas.cc
@@ -13,10 +13,47 @@ // limitations under the License.
 
 #include "paddle/fluid/operators/math/blas.h"
+
+#include <utility>
 
 namespace paddle {
 namespace operators {
 namespace math {
-// Do nothing. Blas is a header only library.
+MatDim GetMatDim(const framework::DDim& dim, int num_flatten_cols, bool trans) {
+  MatDim retv;
+  if (num_flatten_cols > 1) {
+    auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols);
+    retv.height_ = flatten_dim[0];
+    retv.width_ = flatten_dim[1];
+  } else {
+    if (dim.size() == 1) {
+      retv.height_ = 1;
+      retv.width_ = dim[0];
+    } else if (dim.size() == 2) {
+      retv.height_ = dim[0];
+      retv.width_ = dim[1];
+    } else {
+      if (dim.size() == 3) {
+        retv.batch_size_ = dim[0];
+        retv.height_ = dim[1];
+        retv.width_ = dim[2];
+      } else {
+        auto dim_vec = framework::vectorize(dim);
+        retv.batch_size_ = 1;
+        for (size_t i = 0; i < dim_vec.size() - 2; ++i) {
+          retv.batch_size_ *= dim_vec[i];
+        }
+        retv.height_ = dim_vec[dim_vec.size() - 2];
+        retv.width_ = dim_vec[dim_vec.size() - 1];
+      }
+      retv.stride_ = retv.height_ * retv.width_;
+    }
+  }
+  if (trans) {
+    std::swap(retv.width_, retv.height_);
+  }
+  retv.trans_ = trans;
+  return retv;
+}
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 5cd2f855d11..cca967f33f8 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -46,6 +46,17 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+struct MatDim {
+  int64_t height_;
+  int64_t width_;
+  int64_t stride_{0};
+  int64_t batch_size_{0};
+  bool trans_;
+};
+
+extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols,
+                        bool trans);
+
 template <typename DeviceContext>
 class Blas {
  public:
@@ -90,6 +101,28 @@ class Blas {
                    int K, T alpha, const T* A, const T* B, T beta, T* C,
                    int batchCount, int64_t strideA, int64_t strideB) const;
 
+  template <typename T>
+  void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a,
+              const framework::Tensor& mat_b, const MatDim& dim_b, T alpha,
+              framework::Tensor* mat_out, T beta) const {
+    PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
+    CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
+    CBLAS_TRANSPOSE transB = !dim_b.trans_ ?
CblasNoTrans : CblasTrans; + if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { + this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, + dim_a.width_, alpha, mat_a.data(), + mat_b.data(), beta, mat_out->data()); + } else { + PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ || + dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0); + this->template BatchedGEMM( + transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, + mat_a.data(), mat_b.data(), beta, mat_out->data(), + dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, + dim_a.stride_, dim_b.stride_); + } + } + private: const DeviceContext& context_; }; diff --git a/paddle/fluid/operators/math/matmul.h b/paddle/fluid/operators/math/matmul.h deleted file mode 100644 index 87fd38a324e..00000000000 --- a/paddle/fluid/operators/math/matmul.h +++ /dev/null @@ -1,149 +0,0 @@ -/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include "paddle/fluid/operators/math/blas.h" - -namespace paddle { -namespace operators { -namespace math { - -// Implements the logic of numpy matmul: -// https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html -// -// but allowing also for a, b to be transposed -// -// Both a & b can be 1- to 3-dimensional. Higher rank tensors are not supported -// yet. -template -class MatMulFunctor { - public: - void operator()(const DeviceContext& context, const framework::Tensor& a, - bool trans_a, const framework::Tensor& b, bool trans_b, - T alpha, framework::Tensor* out, T beta) { - auto dim_a = a.dims(); - auto dim_b = b.dims(); - - PADDLE_ENFORCE(a.place() == b.place() && b.place() == out->place(), - "Tensors must all be in the same place."); - PADDLE_ENFORCE_GE(dim_a.size(), 1, - "Input tensor a must be at least 1-dimensional."); - PADDLE_ENFORCE_GE(dim_b.size(), 1, - "Input tensor b must be at least 1-dimensional."); - - std::vector out_dim; - int64_t batch_count = 1; - if (dim_a.size() > 3) { - PADDLE_ENFORCE(dim_b.size() == dim_a.size(), - "The dimensions of X and Y must be the same, and both of " - "them should be %d-dimensional.", - dim_b.size()); - // The first rank-2 dimensions are accumulated on the batch_count, and the - // last two dimensions are used for matrix multiplication. - for (int j = 0; j < dim_a.size() - 2; ++j) { - PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j], - "The %d-th dimension of X and Y must be the same.", - j); - out_dim.push_back(dim_a[j]); - batch_count *= dim_a[j]; - } - } - - int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0, - strideA = 0, strideB = 0; - - switch (dim_a.size()) { - case 1: - // similar to np.matmul: - // prepend dimension 1 (no transpose) or append dimension 1 (transpose) - M = trans_a ? dim_a[0] : 1; - kA = trans_a ? 1 : dim_a[0]; - break; - case 2: - M = trans_a ? dim_a[1] : dim_a[0]; - kA = trans_a ? dim_a[0] : dim_a[1]; - break; - case 3: - batchCountA = dim_a[0]; - M = trans_a ? dim_a[2] : dim_a[1]; - kA = trans_a ? 
dim_a[1] : dim_a[2]; - strideA = M * kA; - break; - default: - batchCountA = batch_count; - size_t mat_s = dim_a.size() - 2; - M = trans_a ? dim_a[mat_s + 1] : dim_a[mat_s]; - kA = trans_a ? dim_a[mat_s] : dim_a[mat_s + 1]; - strideA = M * kA; - } - - switch (dim_b.size()) { - case 1: - // similar to np.matmul: - // append dimension 1 (no transpose) or prepend dimension 1 (transpose) - kB = trans_b ? 1 : dim_b[0]; - N = trans_b ? dim_b[0] : 1; - break; - case 2: - kB = trans_b ? dim_b[1] : dim_b[0]; - N = trans_b ? dim_b[0] : dim_b[1]; - break; - case 3: - batchCountB = dim_b[0]; - kB = trans_b ? dim_b[2] : dim_b[1]; - N = trans_b ? dim_b[1] : dim_b[2]; - strideB = kB * N; - break; - default: - batchCountB = batch_count; - size_t mat_s = dim_b.size() - 2; - kB = trans_b ? dim_b[mat_s + 1] : dim_b[mat_s]; - N = trans_b ? dim_b[mat_s] : dim_b[mat_s + 1]; - strideB = kB * N; - } - - PADDLE_ENFORCE_EQ( - kA, kB, - "First matrix's width must be equal with second matrix's height."); - if (batchCountA && batchCountB) { - PADDLE_ENFORCE_EQ( - batchCountA, batchCountB, - "When input tensors a and b are both batched, they must have the " - "same batch dimension."); - } - int batchCount = std::max(batchCountA, batchCountB); - - CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - - auto blas = GetBlas(context); - - if (!batchCount) { - // regular matrix multiplication - blas.GEMM(transA, transB, M, N, kA, alpha, a.data(), b.data(), beta, - out->data()); - } else { - // batched matrix multiplication - blas.BatchedGEMM(transA, transB, M, N, kA, alpha, a.data(), - b.data(), beta, out->data(), batchCount, strideA, - strideB); - } - } -}; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index e5d33fbc364..c285d461e85 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -36,121 +36,39 @@ class MatMulOp : public framework::OperatorWithKernel { auto dim_x = context->GetInputDim("X"); auto dim_y = context->GetInputDim("Y"); - bool transpose_x = context->Attrs().Get("transpose_X"); - bool transpose_y = context->Attrs().Get("transpose_Y"); - - PADDLE_ENFORCE_GE(dim_x.size(), 1, - "Input tensor X must be at least 1-dimensional."); - PADDLE_ENFORCE_GE(dim_y.size(), 1, - "Input tensor Y must be at least 1-dimensional."); - - std::vector out_dim; - int64_t batch_count = 1; - if (dim_x.size() > 3) { - PADDLE_ENFORCE_EQ( - dim_y.size(), dim_x.size(), - "The dimensions of X and Y must be the same, and both of " - "them should be %d-dimensional.", - dim_x.size()); - - // The first rank-2 dimensions are accumulated on the batch_count, and the - // last two dimensions are used for matrix multiplication. - for (int j = 0; j < dim_x.size() - 2; ++j) { - PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j], - "The %d-th dimension of X and Y must be the same.", - j); - out_dim.push_back(dim_x[j]); - batch_count *= dim_x[j]; - } - } - int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0; - bool remove_initial_dim = false, remove_final_dim = false; - - switch (dim_x.size()) { - case 1: - if (transpose_x) { - M = dim_x[0]; - KX = 1; - } else { - M = 1; - KX = dim_x[0]; - remove_initial_dim = true; - } - break; - case 2: - M = transpose_x ? dim_x[1] : dim_x[0]; - KX = transpose_x ? dim_x[0] : dim_x[1]; - break; - case 3: - batchCountX = dim_x[0]; - M = transpose_x ? 
dim_x[2] : dim_x[1]; - KX = transpose_x ? dim_x[1] : dim_x[2]; - break; - default: - batchCountX = batch_count; - size_t mat_s = dim_x.size() - 2; - M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s]; - KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1]; - break; - } + auto mat_dim_x = math::GetMatDim(GetXDim(dim_x), 0, + context->Attrs().Get("transpose_X")); + auto mat_dim_y = math::GetMatDim(GetYDim(dim_y), 0, + context->Attrs().Get("transpose_Y")); - switch (dim_y.size()) { - case 1: - if (transpose_y) { - N = dim_y[0]; - KY = 1; - } else { - N = 1; - KY = dim_y[0]; - remove_final_dim = true; - } - break; - case 2: - KY = transpose_y ? dim_y[1] : dim_y[0]; - N = transpose_y ? dim_y[0] : dim_y[1]; - break; - case 3: - batchCountY = dim_y[0]; - KY = transpose_y ? dim_y[2] : dim_y[1]; - N = transpose_y ? dim_y[1] : dim_y[2]; - break; - default: - batchCountY = batch_count; - size_t mat_s = dim_y.size() - 2; - KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s]; - N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1]; + PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); + PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || + mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0); + std::vector dim_out; + if (mat_dim_x.batch_size_ != 0) { + dim_out = framework::vectorize(dim_x); + dim_out[dim_out.size() - 2] = mat_dim_x.height_; + dim_out[dim_out.size() - 1] = mat_dim_y.width_; + } else if (mat_dim_y.batch_size_ != 0) { + dim_out = framework::vectorize(dim_y); + dim_out[dim_out.size() - 2] = mat_dim_x.height_; + dim_out[dim_out.size() - 1] = mat_dim_y.width_; + } else { + dim_out = {mat_dim_x.height_, mat_dim_y.width_}; } - PADDLE_ENFORCE_EQ( - KX, KY, - "First matrix's width must be equal with second matrix's height."); - if (batchCountX && batchCountY) { - PADDLE_ENFORCE_EQ( - batchCountX, batchCountY, - "When Input(X) and Input(Y) are both three dimensional, they " - "must have the same batch dimension."); + if (dim_x.size() == 1 && dim_out[dim_out.size() - 2] == 1) { + std::swap(dim_out[dim_out.size() - 2], dim_out[dim_out.size() - 1]); + dim_out.resize(dim_out.size() - 1); } - int batchCount = std::max(batchCountX, batchCountY); - std::vector dim_out; - if (batchCount) { - if (dim_x.size() > 3) { - dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end()); - } else { - dim_out.push_back(batchCount); - } + if (dim_y.size() == 1 && dim_out[dim_out.size() - 1] == 1) { + dim_out.resize(dim_out.size() - 1); } - if (!remove_initial_dim) { - dim_out.push_back(M); - } - if (!remove_final_dim) { - dim_out.push_back(N); - } - if (dim_out.size() == 0) { - // We don't support 0-dimensional Tensors (scalars), so instead - // treat the output as a Tensor of shape (1, ) in this case. - dim_out.push_back(1); + + if (dim_out.empty()) { + dim_out = {1}; } context->SetOutputDim("Out", framework::make_ddim(dim_out)); context->ShareLoD("X", /*->*/ "Out"); diff --git a/paddle/fluid/operators/matmul_op.h b/paddle/fluid/operators/matmul_op.h index f2e9cfdcdbf..7b484d124a7 100644 --- a/paddle/fluid/operators/matmul_op.h +++ b/paddle/fluid/operators/matmul_op.h @@ -15,55 +15,56 @@ limitations under the License. 
*/ #pragma once #include #include +#include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/math/matmul.h" namespace paddle { namespace operators { -namespace matmul_detail { +inline framework::DDim GetXDim(const framework::DDim& x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return framework::make_ddim({1, x_dim[0]}); +} -using Tensor = framework::Tensor; -using DDim = framework::DDim; -using framework::make_ddim; -using framework::vectorize; +inline framework::DDim GetYDim(const framework::DDim& y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return framework::make_ddim({y_dim[0], 1}); +} template class MatMulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - const Tensor& x = *context.Input("X"); - const Tensor& y = *context.Input("Y"); - Tensor* out = context.Output("Out"); + auto& x = + detail::Ref(context.Input("X"), "Cannot find X"); + auto& y = + detail::Ref(context.Input("Y"), "Cannot find Y"); + auto* out = context.Output("Out"); out->mutable_data(context.GetPlace()); - bool transpose_x = context.Attr("transpose_X"); - bool transpose_y = context.Attr("transpose_Y"); - math::MatMulFunctor()( - context.template device_context(), x, transpose_x, y, - transpose_y, T(1), out, T(0)); + auto blas = math::GetBlas(context); + auto mat_dim_a = math::GetMatDim(GetXDim(x.dims()), 0, + context.Attr("transpose_X")); + auto mat_dim_b = math::GetMatDim(GetYDim(y.dims()), 0, + context.Attr("transpose_Y")); + blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0)); } }; -template -inline Tensor Reshape(const Tensor& input, const DDim& dims) { - Tensor output; - output.ShareDataWith(input); - output.Resize(dims); - return output; -} - // Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Identity op if the tensor is not of rank 3. -template -Tensor CombineBatchAndM(const Tensor& input) { - Tensor output; - output.ShareDataWith(input); +inline framework::Tensor CombineBatchAndM(const framework::Tensor& input) { + auto output = input; auto in_dims = input.dims(); if (in_dims.size() == 3) { - std::vector out_dims = {in_dims[0] * in_dims[1], in_dims[2]}; - output.Resize(make_ddim(out_dims)); + output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); } return output; } @@ -72,23 +73,57 @@ Tensor CombineBatchAndM(const Tensor& input) { // (Warning: This requires transposing data and writes into new memory.) // Identity op if the tensor is not of rank 3. 
template -Tensor CombineBatchAndN(const DeviceContext& context, const Tensor& input) { - Tensor output; +inline framework::Tensor CombineBatchAndN(const DeviceContext& context, + const framework::Tensor& input) { auto in_dims = input.dims(); - if (in_dims.size() == 3) { - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - output.mutable_data(context.GetPlace()); - std::vector axis = {1, 0, 2}; - math::Transpose trans; - trans(context, input, &output, axis); - std::vector out_dims = {in_dims[1], in_dims[0] * in_dims[2]}; - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - } else { - output.ShareDataWith(input); + if (in_dims.size() != 3) { + return input; } + framework::Tensor output; + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); + output.mutable_data(context.GetPlace()); + std::vector axis = {1, 0, 2}; + math::Transpose trans; + trans(context, input, &output, axis); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); + return output; } +inline void NormalizeTensorShape(framework::Tensor* x, + const math::MatDim& mat_dim_x) { + int64_t h, w; + h = mat_dim_x.height_; + w = mat_dim_x.width_; + if (mat_dim_x.trans_) { + std::swap(w, h); + } + if (mat_dim_x.batch_size_) { + x->Resize({mat_dim_x.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} + +inline void NormalizeXYOutTensorShape(framework::Tensor* x, + framework::Tensor* y, + framework::Tensor* out, bool trans_a, + bool trans_b) { + auto x_dim = GetXDim(x->dims()); + auto y_dim = GetYDim(y->dims()); + auto mat_dim_x = math::GetMatDim(x_dim, 0, trans_a); + auto mat_dim_y = math::GetMatDim(y_dim, 0, trans_b); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, mat_dim_y.width_}); + } + + NormalizeTensorShape(x, mat_dim_x); + NormalizeTensorShape(y, mat_dim_y); +} + // Using dimensional constraints on matrix multiplication, it is // straight-forward to check the following table for when X and Y // are both matrices. @@ -117,128 +152,91 @@ Tensor CombineBatchAndN(const DeviceContext& context, const Tensor& input) { template class MatMulGradKernel : public framework::OpKernel { public: + void MatMul(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + const framework::Tensor& b, bool trans_b, + framework::Tensor* out) const { + out->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + auto mat_dim_a = math::GetMatDim(a.dims(), 0, trans_a); + auto mat_dim_b = math::GetMatDim(b.dims(), 0, trans_b); + blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0)); + } + + void CalcInputGrad(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + bool is_combine_m_a, const framework::Tensor& b, + bool trans_b, bool is_combine_m_b, + framework::Tensor* out) const { + if (out == nullptr) return; + bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && + out->dims().size() == 2; + if (!need_combine) { + MatMul(context, a, trans_a, b, trans_b, out); + } else { + auto& ctx = context.template device_context(); + MatMul( + context, is_combine_m_a ? CombineBatchAndM(a) + : CombineBatchAndN(ctx, a), + trans_a, is_combine_m_b ? 
CombineBatchAndM(b) + : CombineBatchAndN(ctx, b), + trans_b, out); + } + } + void Compute(const framework::ExecutionContext& context) const override { - const Tensor& x = *context.Input("X"); - const Tensor& y = *context.Input("Y"); - const Tensor& dout = *context.Input(framework::GradVarName("Out")); - Tensor* dx = context.Output(framework::GradVarName("X")); - Tensor* dy = context.Output(framework::GradVarName("Y")); + auto x = *context.Input("X"); + auto y = *context.Input("Y"); + auto dout = + *context.Input(framework::GradVarName("Out")); + auto* dx = context.Output(framework::GradVarName("X")); + auto* dy = context.Output(framework::GradVarName("Y")); bool transpose_x = context.Attr("transpose_X"); bool transpose_y = context.Attr("transpose_Y"); - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - - // If X is a vector, reshape it to a matrix. - if (x_dims.size() == 1) { - x_dims.insert(x_dims.begin(), 1); - } - - // If Y is a vector, reshape it to a matrix. - if (y_dims.size() == 1) { - y_dims.push_back(1); + NormalizeXYOutTensorShape(&x, &y, &dout, transpose_x, transpose_y); + framework::DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); + } } - int batch_count = 0; - // The first rank-2 dimensions are accumulated on the batch_count, and the - // last two dimensions are used for matrix multiplication. - if (x_dims.size() > 3) { - batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1, - std::multiplies()); - } - // Fix the dOut dimensions. - int M = 0, N = 0, batchCountX = 0, batchCountY = 0; - - switch (x_dims.size()) { - case 2: - M = transpose_x ? x_dims[1] : x_dims[0]; - break; - case 3: - batchCountX = x_dims[0]; - M = transpose_x ? x_dims[2] : x_dims[1]; - break; - default: - batchCountX = batch_count; - size_t mat_s = x_dims.size() - 2; - M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s]; + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); + } } - switch (y_dims.size()) { - case 2: - N = transpose_y ? y_dims[0] : y_dims[1]; - break; - case 3: - batchCountY = y_dims[0]; - N = transpose_y ? y_dims[1] : y_dims[2]; - break; - default: - batchCountY = batch_count; - size_t mat_s = y_dims.size() - 2; - N = transpose_y ? 
y_dims[mat_s] : y_dims[mat_s + 1]; - } - if (batchCountX && batchCountY) { - PADDLE_ENFORCE_EQ( - batchCountX, batchCountY, - "When Input(X) and Input(Y) are both three dimensional, they " - "must have the same batch dimension."); - } - int batchCount = std::max(batchCountX, batchCountY); - std::vector dout_dims = {M, N}; - if (batchCount) { - if (x_dims.size() > 3) { - dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2); - } else { - dout_dims.insert(dout_dims.begin(), batchCount); - } + if (transpose_x && transpose_y) { + CalcInputGrad(context, y, true, true, dout, true, false, dx); + CalcInputGrad(context, dout, true, true, x, true, false, dy); + } else if (transpose_x && !transpose_y) { + CalcInputGrad(context, y, false, false, dout, true, false, dx); + CalcInputGrad(context, x, false, false, dout, false, true, dy); + } else if (!transpose_x && transpose_y) { + CalcInputGrad(context, dout, false, false, y, false, true, dx); + CalcInputGrad(context, dout, true, true, x, false, true, dy); + } else { + CalcInputGrad(context, dout, false, false, y, true, false, dx); + CalcInputGrad(context, x, true, true, dout, false, true, dy); } - Tensor X = Reshape(x, make_ddim(x_dims)); - Tensor Y = Reshape(y, make_ddim(y_dims)); - Tensor dOut = Reshape(dout, make_ddim(dout_dims)); - auto& dev_ctx = context.template device_context(); if (dx) { - dx->mutable_data(context.GetPlace()); - const Tensor& dOut_for_dX = - (x_dims.size() == 2 && y_dims.size() == 3) - ? CombineBatchAndN(dev_ctx, dOut) - : dOut; - if (x_dims.size() == 2 && y_dims.size() == 3) { - Y = transpose_y ? CombineBatchAndM(Y) - : CombineBatchAndN(dev_ctx, Y); - } - if (transpose_x) { - math::MatMulFunctor()( - dev_ctx, Y, transpose_y, dOut_for_dX, transpose_x, T(1), dx, T(0)); - } else { - math::MatMulFunctor()( - dev_ctx, dOut_for_dX, transpose_x, Y, !transpose_y, T(1), dx, T(0)); + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); } } - if (dy) { - dy->mutable_data(context.GetPlace()); - const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3) - ? CombineBatchAndM(dOut) - : dOut; - if (y_dims.size() == 2 && x_dims.size() == 3) { - X = transpose_x ? 
CombineBatchAndN<DeviceContext, T>(dev_ctx, X)
-                        : CombineBatchAndM<T>(X);
-        dOut = CombineBatchAndM<T>(dOut);
-      }
-      if (transpose_y) {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, dOut_for_dY, transpose_y, X, transpose_x, T(1), dy, T(0));
-      } else {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, X, !transpose_x, dOut_for_dY, transpose_y, T(1), dy, T(0));
+      if (dy_dims != y.dims()) {
+        dy->Resize(dy_dims);
       }
     }
   }
 };
-}  // namespace matmul_detail
-
-using matmul_detail::MatMulKernel;
-using matmul_detail::MatMulGradKernel;
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_matmul_op.py b/python/paddle/fluid/tests/unittests/test_matmul_op.py
index 44ac4683891..cae2c8fa87d 100644
--- a/python/paddle/fluid/tests/unittests/test_matmul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matmul_op.py
@@ -111,21 +111,24 @@ class Generator(object):
 
 
 # Generate test cases for all possibilities
-for dim_X in [1, 2, 3]:
-    for dim_Y in [1, 2, 3]:
-        for transpose_X in [False, True]:
-            for transpose_Y in [False, True]:
-                test_name = (
-                    'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
-                        dim_X, dim_Y, transpose_X, transpose_Y))
-                shape_X, shape_Y = generate_compatible_shapes(
-                    dim_X, dim_Y, transpose_X, transpose_Y)
-                globals()[test_name] = type(test_name, (Generator, OpTest), {
-                    'shape_X': shape_X,
-                    'shape_Y': shape_Y,
-                    'transpose_X': transpose_X,
-                    'transpose_Y': transpose_Y,
-                })
+def inject_test(dim_x, dim_y, trans_x, trans_y):
+    test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
+        dim_x, dim_y, trans_x, trans_y))
+    shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x,
+                                                  trans_y)
+    globals()[test_name] = type(test_name, (Generator, OpTest), {
+        'shape_X': shape_x,
+        'shape_Y': shape_y,
+        'transpose_X': trans_x,
+        'transpose_Y': trans_y,
+    })
+
+
+for dim_X in (1, 2, 3):
+    for dim_Y in (1, 2, 3):
+        for transpose_x in (False, True):
+            for transpose_y in (False, True):
+                inject_test(dim_X, dim_Y, transpose_x, transpose_y)
 
 
 # Test case n-dim
@@ -149,7 +152,7 @@ def generate_compatible_shapes(dim, transpose_X, transpose_Y):
     return shape_X, shape_Y
 
 
-# Test case n-dim
+# # Test case n-dim
 for dim in [4]:
     for transpose_X in [False, True]:
         for transpose_Y in [False, True]:
--
GitLab
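
Note for readers (illustration only, not part of the patch): the sketch below shows how the new math::GetMatDim / Blas::MatMul pair introduced above is meant to be driven from an operator kernel. The helper name RunMatMul and its arguments are hypothetical; it assumes `out` has already been resized and allocated by the caller, and that both inputs are rank >= 2 (rank-1 inputs are first reshaped through GetXDim/GetYDim in matmul_op.h).

    #include "paddle/fluid/framework/op_registry.h"
    #include "paddle/fluid/operators/math/blas.h"

    namespace paddle {
    namespace operators {

    template <typename DeviceContext, typename T>
    void RunMatMul(const framework::ExecutionContext& ctx,
                   const framework::Tensor& x, bool trans_x,
                   const framework::Tensor& y, bool trans_y,
                   framework::Tensor* out) {
      auto blas = math::GetBlas<DeviceContext, T>(ctx);
      // Fold any rank-2/3/N shape into a (batch, height, width) description.
      // num_flatten_cols == 0 keeps the trailing two dimensions as the matrix;
      // any leading dimensions are multiplied into batch_size_.
      auto dim_x = math::GetMatDim(x.dims(), 0, trans_x);
      auto dim_y = math::GetMatDim(y.dims(), 0, trans_y);
      // One call covers both the plain GEMM case (batch_size_ == 0 on both
      // sides) and the strided batched case, including broadcasting when only
      // one operand carries a batch dimension.
      blas.MatMul(x, dim_x, y, dim_y, T(1), out, T(0));
    }

    }  // namespace operators
    }  // namespace paddle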
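Note for readers (illustration only, not part of the patch): the four CalcInputGrad branches in MatMulGradKernel::Compute encode the usual matrix-product derivatives. With Out = X * Y (no transposes), dX = dOut * Y^T and dY = X^T * dOut, which is exactly the final branch (dOut against Y with trans_b = true, and X against dOut with trans_a = true). With both transposes, Out = X^T * Y^T gives dX = Y^T * dOut^T and dY = dOut^T * X^T, the first branch; the two mixed cases follow the same pattern. The is_combine_* flags only choose how a rank-3 operand is folded back to a matrix when the corresponding gradient is rank-2: CombineBatchAndM folds the batch dimension into the rows, while CombineBatchAndN transposes and folds it into the columns.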