diff --git a/paddle/fluid/operators/math/blas.cc b/paddle/fluid/operators/math/blas.cc
index 3eeb77546b97a0337b46216d837a4f4cff12c89f..6a143b3c056455595fdedc131b0c5f4ee756e1e0 100644
--- a/paddle/fluid/operators/math/blas.cc
+++ b/paddle/fluid/operators/math/blas.cc
@@ -13,10 +13,40 @@
 // limitations under the License.

 #include "paddle/fluid/operators/math/blas.h"
+
+#include <utility>

 namespace paddle {
 namespace operators {
 namespace math {
-// Do nothing. Blas is a header only library.
+MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim,
+                                     int num_flatten_cols, bool trans) {
+  PADDLE_ENFORCE_GT(tensor_dim.size(), 1);
+  MatDescriptor retv;
+  if (num_flatten_cols > 1) {
+    auto flatten_dim = framework::flatten_to_2d(tensor_dim, num_flatten_cols);
+    retv.height_ = flatten_dim[0];
+    retv.width_ = flatten_dim[1];
+  } else {
+    if (tensor_dim.size() == 2) {
+      retv.height_ = tensor_dim[0];
+      retv.width_ = tensor_dim[1];
+    } else {
+      auto dim_vec = framework::vectorize(tensor_dim);
+      retv.batch_size_ = 1;
+      for (size_t i = 0; i < dim_vec.size() - 2; ++i) {
+        retv.batch_size_ *= dim_vec[i];
+      }
+      retv.height_ = dim_vec[dim_vec.size() - 2];
+      retv.width_ = dim_vec[dim_vec.size() - 1];
+      retv.stride_ = retv.height_ * retv.width_;
+    }
+  }
+  if (trans) {
+    std::swap(retv.width_, retv.height_);
+  }
+  retv.trans_ = trans;
+  return retv;
+}
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 5cd2f855d1135e6dd8343efdaa9855d2526a3520..dabde43850db770d286b13cacd32bee181328d5c 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -46,6 +46,50 @@ namespace paddle {
 namespace operators {
 namespace math {

+/**
+ * Matrix Descriptor of a memory buffer.
+ *
+ * It is used for Blas::MatMul. MatMul can be batched: if Mat A is
+ * [BatchSize, H, W] and Mat B is [BatchSize, H, W], the call becomes
+ * `batch_size` GEMMs. The batched GEMM could be faster, depending on the
+ * implementation of the BLAS library. The batch size may be zero. If any
+ * operand of `matmul` has a batch size, the result is a batched GEMM as
+ * well, e.g., if Mat A is [BatchSize, H1, W1] and Mat B is [H2, W2], the
+ * result matrix will be [BatchSize, H1, W2].
+ *
+ * The boolean flag `trans` describes whether the memory holds the
+ * transpose of the matrix. If `trans` is true, the last two dims of the
+ * matrix are transposed, i.e., the memory layout of the matrix is
+ * [Width, Height] or [BatchSize, Width, Height].
+ *
+ * The MatDescriptor is not only the dimension or shape of a matrix; it
+ * also carries the layout and stride of the matrix. It is clearer to have
+ * a dedicated structure than to reuse `DDim`.
+ */
+struct MatDescriptor {
+  int64_t height_;
+  int64_t width_;
+  int64_t stride_{0};
+  int64_t batch_size_{0};
+  bool trans_;
+};
+
+/**
+ * Create a matrix descriptor from a tensor dim, num_flatten_cols, and a
+ * transpose flag.
+ *
+ * @param tensor_dim: The dimension of the tensor. The rank of this
+ * dimension must be larger than 1.
+ *
+ * @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first
+ * dimension (column length) will be the product of the tensor's first
+ * `num_flatten_cols` dimensions. If num_flatten_cols is zero, the first
+ * N-2 dimensions form the batch_size of the descriptor.
+ *
+ * @param trans: True if the matrix is transposed.
+ */
+extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim,
+                                            int num_flatten_cols, bool trans);
+
 template <typename DeviceContext>
 class Blas {
  public:
@@ -90,6 +134,11 @@ class Blas {
           int K, T alpha, const T* A, const T* B, T beta, T* C,
           int batchCount, int64_t strideA, int64_t strideB) const;

+  template <typename T>
+  void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
+              const framework::Tensor& mat_b, const MatDescriptor& dim_b,
+              T alpha, framework::Tensor* mat_out, T beta) const;
+
  private:
   const DeviceContext& context_;
 };
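To make the descriptor semantics concrete, here is a standalone sketch (not part of the patch; `Desc` and `Describe` are hypothetical stand-ins for `MatDescriptor` and the `num_flatten_cols <= 1` branch of `CreateMatrixDescriptor`):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone mirror of MatDescriptor, for illustration only.
struct Desc {
  int64_t height = 0, width = 0, stride = 0, batch_size = 0;
};

// Mirrors the num_flatten_cols <= 1 branch of CreateMatrixDescriptor:
// rank-2 tensors map directly, higher ranks fold all but the last two
// dims into batch_size and record a per-matrix stride.
Desc Describe(const std::vector<int64_t>& dims) {
  Desc d;
  if (dims.size() == 2) {
    d.height = dims[0];
    d.width = dims[1];
    return d;  // batch_size stays 0: a plain, unbatched matrix
  }
  d.batch_size = 1;
  for (size_t i = 0; i + 2 < dims.size(); ++i) d.batch_size *= dims[i];
  d.height = dims[dims.size() - 2];
  d.width = dims[dims.size() - 1];
  d.stride = d.height * d.width;
  return d;
}

int main() {
  // A [2, 3, 4, 5] tensor is a batch of 2*3 = 6 matrices of 4 x 5.
  Desc d = Describe({2, 3, 4, 5});
  std::cout << d.batch_size << " x (" << d.height << " x " << d.width
            << "), stride " << d.stride << "\n";  // 6 x (4 x 5), stride 20
}
```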
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index 7360cc0a90da499c372c6fb3f8d40a26f9093dd8..577cbe3beb806ffcb2f1a7d7a469402be9b69224 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -180,6 +180,31 @@ void Blas<DeviceContext>::BatchedGEMM(
 #endif
 }

+template <typename DeviceContext>
+template <typename T>
+void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
+                                 const MatDescriptor &dim_a,
+                                 const framework::Tensor &mat_b,
+                                 const MatDescriptor &dim_b, T alpha,
+                                 framework::Tensor *mat_out, T beta) const {
+  PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
+  CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
+  if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
+    this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
+                           dim_a.width_, alpha, mat_a.data<T>(),
+                           mat_b.data<T>(), beta, mat_out->data<T>());
+  } else {
+    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
+                   dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
+    this->template BatchedGEMM<T>(
+        transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
+        mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
+        dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
+        dim_a.stride_, dim_b.stride_);
+  }
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
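A sketch of the dispatch rule in `Blas::MatMul` above, with hypothetical names, showing when it falls back to a single GEMM and how a zero batch size broadcasts against a batched operand:

```cpp
#include <cstdint>
#include <stdexcept>

// Illustration only: which GEMM path MatMul takes. In the patch, the
// unbatched side keeps stride_ == 0, so every GEMM in the batch re-reads
// the same buffer.
struct MatDims { int64_t batch_size, height, width, stride; };

int64_t BatchedCallCount(const MatDims& a, const MatDims& b) {
  if (a.batch_size == 0 && b.batch_size == 0) {
    return 0;  // plain single GEMM
  }
  if (a.batch_size != 0 && b.batch_size != 0 && a.batch_size != b.batch_size) {
    throw std::invalid_argument("batch sizes must match, or one must be 0");
  }
  return a.batch_size == 0 ? b.batch_size : a.batch_size;
}

int main() {
  MatDims a{6, 4, 5, 20};  // batched operand: [6, 4, 5]
  MatDims b{0, 5, 7, 0};   // plain [5, 7], broadcast across the batch
  return BatchedCallCount(a, b) == 6 ? 0 : 1;
}
```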
diff --git a/paddle/fluid/operators/math/matmul.h b/paddle/fluid/operators/math/matmul.h
deleted file mode 100644
index 87fd38a324e007bcc939c31b6ae8e5d38c3e658c..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/math/matmul.h
+++ /dev/null
@@ -1,149 +0,0 @@
-/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <algorithm>
-#include <vector>
-#include "paddle/fluid/operators/math/blas.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-// Implements the logic of numpy matmul:
-// https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
-//
-// but allowing also for a, b to be transposed
-//
-// Both a & b can be 1- to 3-dimensional. Higher rank tensors are not supported
-// yet.
-template <typename DeviceContext, typename T>
-class MatMulFunctor {
- public:
-  void operator()(const DeviceContext& context, const framework::Tensor& a,
-                  bool trans_a, const framework::Tensor& b, bool trans_b,
-                  T alpha, framework::Tensor* out, T beta) {
-    auto dim_a = a.dims();
-    auto dim_b = b.dims();
-
-    PADDLE_ENFORCE(a.place() == b.place() && b.place() == out->place(),
-                   "Tensors must all be in the same place.");
-    PADDLE_ENFORCE_GE(dim_a.size(), 1,
-                      "Input tensor a must be at least 1-dimensional.");
-    PADDLE_ENFORCE_GE(dim_b.size(), 1,
-                      "Input tensor b must be at least 1-dimensional.");
-
-    std::vector<int64_t> out_dim;
-    int64_t batch_count = 1;
-    if (dim_a.size() > 3) {
-      PADDLE_ENFORCE(dim_b.size() == dim_a.size(),
-                     "The dimensions of X and Y must be the same, and both of "
-                     "them should be %d-dimensional.",
-                     dim_b.size());
-      // The first rank-2 dimensions are accumulated on the batch_count, and
-      // the last two dimensions are used for matrix multiplication.
-      for (int j = 0; j < dim_a.size() - 2; ++j) {
-        PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j],
-                          "The %d-th dimension of X and Y must be the same.",
-                          j);
-        out_dim.push_back(dim_a[j]);
-        batch_count *= dim_a[j];
-      }
-    }
-
-    int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0,
-        strideA = 0, strideB = 0;
-
-    switch (dim_a.size()) {
-      case 1:
-        // similar to np.matmul:
-        // prepend dimension 1 (no transpose) or append dimension 1 (transpose)
-        M = trans_a ? dim_a[0] : 1;
-        kA = trans_a ? 1 : dim_a[0];
-        break;
-      case 2:
-        M = trans_a ? dim_a[1] : dim_a[0];
-        kA = trans_a ? dim_a[0] : dim_a[1];
-        break;
-      case 3:
-        batchCountA = dim_a[0];
-        M = trans_a ? dim_a[2] : dim_a[1];
-        kA = trans_a ? dim_a[1] : dim_a[2];
-        strideA = M * kA;
-        break;
-      default:
-        batchCountA = batch_count;
-        size_t mat_s = dim_a.size() - 2;
-        M = trans_a ? dim_a[mat_s + 1] : dim_a[mat_s];
-        kA = trans_a ? dim_a[mat_s] : dim_a[mat_s + 1];
-        strideA = M * kA;
-    }
-
-    switch (dim_b.size()) {
-      case 1:
-        // similar to np.matmul:
-        // append dimension 1 (no transpose) or prepend dimension 1 (transpose)
-        kB = trans_b ? 1 : dim_b[0];
-        N = trans_b ? dim_b[0] : 1;
-        break;
-      case 2:
-        kB = trans_b ? dim_b[1] : dim_b[0];
-        N = trans_b ? dim_b[0] : dim_b[1];
-        break;
-      case 3:
-        batchCountB = dim_b[0];
-        kB = trans_b ? dim_b[2] : dim_b[1];
-        N = trans_b ? dim_b[1] : dim_b[2];
-        strideB = kB * N;
-        break;
-      default:
-        batchCountB = batch_count;
-        size_t mat_s = dim_b.size() - 2;
-        kB = trans_b ? dim_b[mat_s + 1] : dim_b[mat_s];
-        N = trans_b ? dim_b[mat_s] : dim_b[mat_s + 1];
-        strideB = kB * N;
-    }
-
-    PADDLE_ENFORCE_EQ(
-        kA, kB,
-        "First matrix's width must be equal with second matrix's height.");
-    if (batchCountA && batchCountB) {
-      PADDLE_ENFORCE_EQ(
-          batchCountA, batchCountB,
-          "When input tensors a and b are both batched, they must have the "
-          "same batch dimension.");
-    }
-    int batchCount = std::max(batchCountA, batchCountB);
-
-    CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
-    CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
-
-    auto blas = GetBlas<DeviceContext, T>(context);
-
-    if (!batchCount) {
-      // regular matrix multiplication
-      blas.GEMM(transA, transB, M, N, kA, alpha, a.data<T>(), b.data<T>(),
-                beta, out->data<T>());
-    } else {
-      // batched matrix multiplication
-      blas.BatchedGEMM(transA, transB, M, N, kA, alpha, a.data<T>(),
-                       b.data<T>(), beta, out->data<T>(), batchCount, strideA,
-                       strideB);
-    }
-  }
-};
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index e5d33fbc36438f97ff5b604e4efdbfbfa91fcee4..da21b8ad7d4e353e1dbe98fde1fbac1b0d37fd5d 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -12,14 +12,257 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/matmul_op.h"
 #include <algorithm>
+#include <utility>
 #include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/math/blas.h"

 namespace paddle {
 namespace operators {

+/**
+ * Get a row-matrix shape from a vector shape. If the rank of x_dim > 1, the
+ * original x_dim is returned.
+ */
+static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
+  if (x_dim.size() > 1) {
+    return x_dim;
+  }
+  return framework::make_ddim({1, x_dim[0]});
+}
+
+/**
+ * Get a column-matrix shape from a vector shape. If the rank of y_dim > 1,
+ * the original y_dim is returned.
+ */
+static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
+  if (y_dim.size() > 1) {
+    return y_dim;
+  }
+  return framework::make_ddim({y_dim[0], 1});
+}
+
+template <typename DeviceContext, typename T>
+class MatMulKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto& x =
+        detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
+    auto& y =
+        detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
+    auto* out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    auto blas = math::GetBlas<DeviceContext, T>(context);
+    auto mat_dim_a = math::CreateMatrixDescriptor(
+        RowMatrixFromVector(x.dims()), 0, context.Attr<bool>("transpose_X"));
+    auto mat_dim_b = math::CreateMatrixDescriptor(
+        ColumnMatrixFromVector(y.dims()), 0, context.Attr<bool>("transpose_Y"));
+    blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
+  }
+};
+
+// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
+// Identity op if the tensor is not of rank 3.
+static framework::Tensor FoldInitDims(const framework::Tensor& input) {
+  auto output = input;
+  auto in_dims = input.dims();
+  if (in_dims.size() == 3) {
+    output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
+  }
+  return output;
+}
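The `RowMatrixFromVector`/`ColumnMatrixFromVector` helpers above implement numpy-style vector promotion. A minimal sketch of the shape rule (hypothetical helper names, not part of the patch):

```cpp
#include <cstdint>
#include <vector>

// Numpy-style promotion used by MatMulKernel (illustration only):
// a 1-D X of size K acts as [1, K]; a 1-D Y of size K acts as [K, 1].
std::vector<int64_t> PromoteRow(std::vector<int64_t> d) {
  if (d.size() == 1) d.insert(d.begin(), 1);  // X: [K] -> [1, K]
  return d;
}
std::vector<int64_t> PromoteCol(std::vector<int64_t> d) {
  if (d.size() == 1) d.push_back(1);  // Y: [K] -> [K, 1]
  return d;
}

int main() {
  // X: [3] and Y: [3] multiply as [1, 3] x [3, 1] -> [1, 1]; InferShape
  // later squeezes the prepended/appended 1s back out of the output.
  auto x = PromoteRow({3});  // {1, 3}
  auto y = PromoteCol({3});  // {3, 1}
  return (x[0] == 1 && y[1] == 1) ? 0 : 1;
}
```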
+
+// Reshape a rank-3 tensor from P x M x N to M x (P * N).
+// (Warning: This requires transposing data and writes into new memory.)
+// Identity op if the tensor is not of rank 3.
+template <typename DeviceContext, typename T>
+static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
+                                             const framework::Tensor& input) {
+  auto in_dims = input.dims();
+  if (in_dims.size() != 3) {
+    return input;
+  }
+  framework::Tensor output;
+  output.Resize({in_dims[1], in_dims[0], in_dims[2]});
+  output.mutable_data<T>(context.GetPlace());
+  std::vector<int> axis = {1, 0, 2};
+  math::Transpose<DeviceContext, T, 3> trans;
+  trans(context, input, &output, axis);
+  output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
+
+  return output;
+}
+
+/**
+ * Reshape a tensor to a 3-D or 2-D tensor by its matrix descriptor.
+ *
+ * The shape will be [BatchSize, H, W] or [H, W].
+ * If transposed, `H, W` will be swapped.
+ */
+static void ReshapeTensorIntoMatrixSequence(
+    framework::Tensor* x, const math::MatDescriptor& descriptor) {
+  int64_t h, w;
+  h = descriptor.height_;
+  w = descriptor.width_;
+  if (descriptor.trans_) {
+    std::swap(w, h);
+  }
+  if (descriptor.batch_size_) {
+    x->Resize({descriptor.batch_size_, h, w});
+  } else {
+    x->Resize({h, w});
+  }
+}
+
+/**
+ * Reshape the x, y and out tensors to 3-D or 2-D tensors by their matrix
+ * descriptors, where Out = matmul(x, y).
+ *
+ * This method first computes the X, Y matrix sequences and then derives
+ * the out shape.
+ *
+ * Assume X = [BatchSize, H1, W1] and Y = [BatchSize, H2, W2].
+ * Then out = [BatchSize, H1, W2].
+ *
+ * If there is no batch size in `X` and `Y`, the out will be [H1, W2].
+ * If either `X` or `Y` has a batch size BatchSize, the out will carry
+ * that BatchSize.
+ */
+static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
+                                           framework::Tensor* y,
+                                           framework::Tensor* out, bool trans_x,
+                                           bool trans_y) {
+  auto x_dim = RowMatrixFromVector(x->dims());
+  auto y_dim = ColumnMatrixFromVector(y->dims());
+  auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x);
+  auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y);
+  if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
+    out->Resize({mat_dim_x.height_, mat_dim_y.width_});
+  } else {
+    out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
+                 mat_dim_x.height_, mat_dim_y.width_});
+  }
+
+  ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
+  ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
+}
+
+// Using dimensional constraints on matrix multiplication, it is
+// straight-forward to check the following table for when X and Y
+// are both matrices.
+//
+// transpose_X | False    | True     | False    | True
+// transpose_Y | False    | False    | True     | True
+// ------------+----------+----------+----------+-----------
+//        dX = | dOut Y^T | Y dOut^T | dOut Y   | Y^T dOut^T
+//        dY = | X^T dOut | X dOut   | dOut^T X | dOut^T X^T
+//
+// When X is a vector of size K, we treat it instead as a matrix of shape
+// (1, K). Similarly, when Y is a vector of size K, we treat it instead as
+// a matrix of shape (K, 1).
+//
+// When X and Y are both 3-dimensional tensors, then the first dimension,
+// the batch dimension, can be ignored and the exact same formulas apply
+// as for two matrices.
+//
+// Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end
+// up with formulas like
+//
+//   dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj}
+//
+// To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N
+// to X: (P * M) x K, dOut: (P * M) x N.
+template <typename DeviceContext, typename T>
+class MatMulGradKernel : public framework::OpKernel<T> {
+ public:
+  void MatMul(const framework::ExecutionContext& context,
+              const framework::Tensor& a, bool trans_a,
+              const framework::Tensor& b, bool trans_b,
+              framework::Tensor* out) const {
+    out->mutable_data<T>(context.GetPlace());
+    auto blas = math::GetBlas<DeviceContext, T>(context);
+    auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
+    auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
+    blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
+  }
+
+  void CalcInputGrad(const framework::ExecutionContext& context,
+                     const framework::Tensor& a, bool trans_a,
+                     bool is_fold_init_dims_a, const framework::Tensor& b,
+                     bool trans_b, bool is_fold_init_dims_b,
+                     framework::Tensor* out) const {
+    if (out == nullptr) return;
+    bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
+                        out->dims().size() == 2;
+    if (!need_combine) {
+      MatMul(context, a, trans_a, b, trans_b, out);
+    } else {
+      auto& ctx = context.template device_context<DeviceContext>();
+      MatMul(context, is_fold_init_dims_a
+                          ? FoldInitDims(a)
+                          : FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
+             trans_a, is_fold_init_dims_b
+                          ? FoldInitDims(b)
+                          : FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
+             trans_b, out);
+    }
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto x = *context.Input<framework::Tensor>("X");
+    auto y = *context.Input<framework::Tensor>("Y");
+    auto dout =
+        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
+    bool transpose_x = context.Attr<bool>("transpose_X");
+    bool transpose_y = context.Attr<bool>("transpose_Y");
+
+    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
+    framework::DDim dx_dims;
+    if (dx) {
+      dx_dims = dx->dims();
+      if (dx_dims != x.dims()) {
+        dx->Resize(x.dims());
+      }
+    }
+
+    framework::DDim dy_dims;
+    if (dy) {
+      dy_dims = dy->dims();
+      if (dy_dims != y.dims()) {
+        dy->Resize(y.dims());
+      }
+    }
-
-using framework::Tensor;
+
+    if (transpose_x && transpose_y) {
+      CalcInputGrad(context, y, true, true, dout, true, false, dx);
+      CalcInputGrad(context, dout, true, true, x, true, false, dy);
+    } else if (transpose_x) {
+      CalcInputGrad(context, y, false, false, dout, true, false, dx);
+      CalcInputGrad(context, x, false, false, dout, false, true, dy);
+    } else if (transpose_y) {
+      CalcInputGrad(context, dout, false, false, y, false, true, dx);
+      CalcInputGrad(context, dout, true, true, x, false, true, dy);
+    } else {
+      CalcInputGrad(context, dout, false, false, y, true, false, dx);
+      CalcInputGrad(context, x, true, true, dout, false, true, dy);
+    }
+
+    if (dx) {
+      if (dx_dims != x.dims()) {
+        dx->Resize(dx_dims);
+      }
+    }
+    if (dy) {
+      if (dy_dims != y.dims()) {
+        dy->Resize(dy_dims);
+      }
+    }
+  }
+};

 class MatMulOp : public framework::OperatorWithKernel {
  public:
@@ -36,121 +279,41 @@ class MatMulOp : public framework::OperatorWithKernel {
     auto dim_x = context->GetInputDim("X");
     auto dim_y = context->GetInputDim("Y");

-    bool transpose_x = context->Attrs().Get<bool>("transpose_X");
-    bool transpose_y = context->Attrs().Get<bool>("transpose_Y");
-
-    PADDLE_ENFORCE_GE(dim_x.size(), 1,
-                      "Input tensor X must be at least 1-dimensional.");
-    PADDLE_ENFORCE_GE(dim_y.size(), 1,
-                      "Input tensor Y must be at least 1-dimensional.");
-
-    std::vector<int64_t> out_dim;
-    int64_t batch_count = 1;
-    if (dim_x.size() > 3) {
-      PADDLE_ENFORCE_EQ(
-          dim_y.size(), dim_x.size(),
-          "The dimensions of X and Y must be the same, and both of "
-          "them should be %d-dimensional.",
-          dim_x.size());
-
-      // The first rank-2 dimensions are accumulated on the batch_count, and
-      // the last two dimensions are used for matrix multiplication.
-      for (int j = 0; j < dim_x.size() - 2; ++j) {
-        PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
-                          "The %d-th dimension of X and Y must be the same.",
-                          j);
-        out_dim.push_back(dim_x[j]);
-        batch_count *= dim_x[j];
-      }
-    }
-    int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
-    bool remove_initial_dim = false, remove_final_dim = false;
-
-    switch (dim_x.size()) {
-      case 1:
-        if (transpose_x) {
-          M = dim_x[0];
-          KX = 1;
-        } else {
-          M = 1;
-          KX = dim_x[0];
-          remove_initial_dim = true;
-        }
-        break;
-      case 2:
-        M = transpose_x ? dim_x[1] : dim_x[0];
-        KX = transpose_x ? dim_x[0] : dim_x[1];
-        break;
-      case 3:
-        batchCountX = dim_x[0];
-        M = transpose_x ? dim_x[2] : dim_x[1];
-        KX = transpose_x ? dim_x[1] : dim_x[2];
-        break;
-      default:
-        batchCountX = batch_count;
-        size_t mat_s = dim_x.size() - 2;
-        M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
-        KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
-        break;
-    }
+    auto mat_dim_x =
+        math::CreateMatrixDescriptor(RowMatrixFromVector(dim_x), 0,
+                                     context->Attrs().Get<bool>("transpose_X"));
+    auto mat_dim_y =
+        math::CreateMatrixDescriptor(ColumnMatrixFromVector(dim_y), 0,
+                                     context->Attrs().Get<bool>("transpose_Y"));

-    switch (dim_y.size()) {
-      case 1:
-        if (transpose_y) {
-          N = dim_y[0];
-          KY = 1;
-        } else {
-          N = 1;
-          KY = dim_y[0];
-          remove_final_dim = true;
-        }
-        break;
-      case 2:
-        KY = transpose_y ? dim_y[1] : dim_y[0];
-        N = transpose_y ? dim_y[0] : dim_y[1];
-        break;
-      case 3:
-        batchCountY = dim_y[0];
-        KY = transpose_y ? dim_y[2] : dim_y[1];
-        N = transpose_y ? dim_y[1] : dim_y[2];
-        break;
-      default:
-        batchCountY = batch_count;
-        size_t mat_s = dim_y.size() - 2;
-        KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
-        N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
+    PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
+    PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
+                   mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
+    std::vector<int64_t> dim_out;
+    if (mat_dim_x.batch_size_ != 0) {
+      dim_out = framework::vectorize(dim_x);
+      dim_out[dim_out.size() - 2] = mat_dim_x.height_;
+      dim_out[dim_out.size() - 1] = mat_dim_y.width_;
+    } else if (mat_dim_y.batch_size_ != 0) {
+      dim_out = framework::vectorize(dim_y);
+      dim_out[dim_out.size() - 2] = mat_dim_x.height_;
+      dim_out[dim_out.size() - 1] = mat_dim_y.width_;
+    } else {
+      dim_out = {mat_dim_x.height_, mat_dim_y.width_};
     }

-    PADDLE_ENFORCE_EQ(
-        KX, KY,
-        "First matrix's width must be equal with second matrix's height.");
-    if (batchCountX && batchCountY) {
-      PADDLE_ENFORCE_EQ(
-          batchCountX, batchCountY,
-          "When Input(X) and Input(Y) are both three dimensional, they "
-          "must have the same batch dimension.");
+    if (dim_x.size() == 1 && dim_out[dim_out.size() - 2] == 1) {
+      std::swap(dim_out[dim_out.size() - 2], dim_out[dim_out.size() - 1]);
+      dim_out.resize(dim_out.size() - 1);
     }
-    int batchCount = std::max(batchCountX, batchCountY);

-    std::vector<int64_t> dim_out;
-    if (batchCount) {
-      if (dim_x.size() > 3) {
-        dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
-      } else {
-        dim_out.push_back(batchCount);
-      }
+    if (dim_y.size() == 1 && dim_out[dim_out.size() - 1] == 1) {
+      dim_out.resize(dim_out.size() - 1);
     }
-    if (!remove_initial_dim) {
-      dim_out.push_back(M);
-    }
-    if (!remove_final_dim) {
-      dim_out.push_back(N);
-    }
-    if (dim_out.size() == 0) {
-      // We don't support 0-dimensional Tensors (scalars), so instead
-      // treat the output as a Tensor of shape (1, ) in this case.
-      dim_out.push_back(1);
+
+    if (dim_out.empty()) {
+      dim_out = {1};
     }
     context->SetOutputDim("Out", framework::make_ddim(dim_out));
     context->ShareLoD("X", /*->*/ "Out");
@@ -233,15 +396,40 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
   }
 };

+class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* retv = new framework::OpDesc();
+    retv->SetType("matmul_grad");
+    retv->SetInput("X", Input("X"));
+    retv->SetInput("Y", Input("Y"));
+    retv->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    retv->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    retv->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
+    retv->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(retv);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::MatMulOpGradMaker);
 REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
 REGISTER_OP_CPU_KERNEL(
     matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(
     matmul_grad,
     ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>);
+
+#ifdef PADDLE_WITH_CUDA
+REGISTER_OP_CUDA_KERNEL(
+    matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
+    matmul_grad,
+    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>);
+#endif
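To see how `MatMulGradKernel::Compute` maps the transpose table onto forward `MatMul` calls, here is a sketch with a print stub in place of the real GEMM (illustration only; the argument order mirrors the dispatch in the kernel above):

```cpp
#include <cstdio>

// Stub standing in for the kernel's MatMul helper: prints which product
// would be formed. In the real kernel this is a GEMM on tensors, and
// Out = op(A) * op(B) with op(.) optionally transposing.
static void MatMul(const char* a, bool ta, const char* b, bool tb,
                   const char* out) {
  std::printf("%s = %s%s * %s%s\n", out, a, ta ? "^T" : "", b, tb ? "^T" : "");
}

// The four (transpose_X, transpose_Y) cases from the table above, in the
// same order the new MatMulGradKernel::Compute dispatches them.
void GradCalls(bool tx, bool ty) {
  if (tx && ty) {
    MatMul("Y", true, "dOut", true, "dX");    // dX = Y^T dOut^T
    MatMul("dOut", true, "X", true, "dY");    // dY = dOut^T X^T
  } else if (tx) {
    MatMul("Y", false, "dOut", true, "dX");   // dX = Y dOut^T
    MatMul("X", false, "dOut", false, "dY");  // dY = X dOut
  } else if (ty) {
    MatMul("dOut", false, "Y", false, "dX");  // dX = dOut Y
    MatMul("dOut", true, "X", false, "dY");   // dY = dOut^T X
  } else {
    MatMul("dOut", false, "Y", true, "dX");   // dX = dOut Y^T
    MatMul("X", true, "dOut", false, "dY");   // dY = X^T dOut
  }
}

int main() { GradCalls(false, false); }
```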
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/matmul_op.h"
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>);
-REGISTER_OP_CUDA_KERNEL(
-    matmul_grad,
-    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>);
diff --git a/paddle/fluid/operators/matmul_op.h b/paddle/fluid/operators/matmul_op.h
deleted file mode 100644
index f2e9cfdcdbf93326ae193776a7d5f6a324373603..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/matmul_op.h
+++ /dev/null
@@ -1,244 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <algorithm>
-#include <functional>
-#include <vector>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/operators/math/matmul.h"
-
-namespace paddle {
-namespace operators {
-namespace matmul_detail {
-
-using Tensor = framework::Tensor;
-using DDim = framework::DDim;
-using framework::make_ddim;
-using framework::vectorize;
-
-template <typename DeviceContext, typename T>
-class MatMulKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor& x = *context.Input<Tensor>("X");
-    const Tensor& y = *context.Input<Tensor>("Y");
-    Tensor* out = context.Output<Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-    bool transpose_x = context.Attr<bool>("transpose_X");
-    bool transpose_y = context.Attr<bool>("transpose_Y");
-
-    math::MatMulFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), x, transpose_x, y,
-        transpose_y, T(1), out, T(0));
-  }
-};
-
-template <typename T>
-inline Tensor Reshape(const Tensor& input, const DDim& dims) {
-  Tensor output;
-  output.ShareDataWith(input);
-  output.Resize(dims);
-  return output;
-}
-
-// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
-// Identity op if the tensor is not of rank 3.
-template <typename T>
-Tensor CombineBatchAndM(const Tensor& input) {
-  Tensor output;
-  output.ShareDataWith(input);
-  auto in_dims = input.dims();
-  if (in_dims.size() == 3) {
-    std::vector<int64_t> out_dims = {in_dims[0] * in_dims[1], in_dims[2]};
-    output.Resize(make_ddim(out_dims));
-  }
-  return output;
-}
-
-// Reshape a rank-3 tensor from P x M x N to M x (P * N).
-// (Warning: This requires transposing data and writes into new memory.)
-// Identity op if the tensor is not of rank 3.
-template <typename DeviceContext, typename T>
-Tensor CombineBatchAndN(const DeviceContext& context, const Tensor& input) {
-  Tensor output;
-  auto in_dims = input.dims();
-  if (in_dims.size() == 3) {
-    output.Resize({in_dims[1], in_dims[0], in_dims[2]});
-    output.mutable_data<T>(context.GetPlace());
-    std::vector<int> axis = {1, 0, 2};
-    math::Transpose<DeviceContext, T, 3> trans;
-    trans(context, input, &output, axis);
-    std::vector<int64_t> out_dims = {in_dims[1], in_dims[0] * in_dims[2]};
-    output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
-  } else {
-    output.ShareDataWith(input);
-  }
-  return output;
-}
-
-// Using dimensional constraints on matrix multiplication, it is
-// straight-forward to check the following table for when X and Y
-// are both matrices.
-//
-// transpose_X | False    | True     | False    | True
-// transpose_Y | False    | False    | True     | True
-// ------------+----------+----------+----------+-----------
-//        dX = | dOut Y^T | Y dOut^T | dOut Y   | Y^T dOut^T
-//        dY = | X^T dOut | X dOut   | dOut^T X | dOut^T X^T
-//
-// When X is a vector of size K, we treat it instead as a matrix of shape
-// (1, K). Similarly, when Y is a vector of size K, we treat it instead as
-// a matrix of shape (K, 1).
-//
-// When X and Y are both 3-dimensional tensors, then the first dimension,
-// the batch dimension, can be ignored and the exact same formulas apply
-// as for two matrices.
-//
-// Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end
-// up with formulas like
-//
-//   dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj}
-//
-// To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N
-// to X: (P * M) x K, dOut: (P * M) x N.
-template <typename DeviceContext, typename T>
-class MatMulGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor& x = *context.Input<Tensor>("X");
-    const Tensor& y = *context.Input<Tensor>("Y");
-    const Tensor& dout = *context.Input<Tensor>(framework::GradVarName("Out"));
-    Tensor* dx = context.Output<Tensor>(framework::GradVarName("X"));
-    Tensor* dy = context.Output<Tensor>(framework::GradVarName("Y"));
-    bool transpose_x = context.Attr<bool>("transpose_X");
-    bool transpose_y = context.Attr<bool>("transpose_Y");
-
-    std::vector<int64_t> x_dims = vectorize(x.dims());
-    std::vector<int64_t> y_dims = vectorize(y.dims());
-
-    // If X is a vector, reshape it to a matrix.
-    if (x_dims.size() == 1) {
-      x_dims.insert(x_dims.begin(), 1);
-    }
-
-    // If Y is a vector, reshape it to a matrix.
-    if (y_dims.size() == 1) {
-      y_dims.push_back(1);
-    }
-
-    int batch_count = 0;
-    // The first rank-2 dimensions are accumulated on the batch_count, and the
-    // last two dimensions are used for matrix multiplication.
-    if (x_dims.size() > 3) {
-      batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
-                               std::multiplies<int>());
-    }
-    // Fix the dOut dimensions.
-    int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
-
-    switch (x_dims.size()) {
-      case 2:
-        M = transpose_x ? x_dims[1] : x_dims[0];
-        break;
-      case 3:
-        batchCountX = x_dims[0];
-        M = transpose_x ? x_dims[2] : x_dims[1];
-        break;
-      default:
-        batchCountX = batch_count;
-        size_t mat_s = x_dims.size() - 2;
-        M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
-    }
-
-    switch (y_dims.size()) {
-      case 2:
-        N = transpose_y ? y_dims[0] : y_dims[1];
-        break;
-      case 3:
-        batchCountY = y_dims[0];
-        N = transpose_y ? y_dims[1] : y_dims[2];
-        break;
-      default:
-        batchCountY = batch_count;
-        size_t mat_s = y_dims.size() - 2;
-        N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1];
-    }
-    if (batchCountX && batchCountY) {
-      PADDLE_ENFORCE_EQ(
-          batchCountX, batchCountY,
-          "When Input(X) and Input(Y) are both three dimensional, they "
-          "must have the same batch dimension.");
-    }
-    int batchCount = std::max(batchCountX, batchCountY);
-    std::vector<int64_t> dout_dims = {M, N};
-    if (batchCount) {
-      if (x_dims.size() > 3) {
-        dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2);
-      } else {
-        dout_dims.insert(dout_dims.begin(), batchCount);
-      }
-    }
-    Tensor X = Reshape<T>(x, make_ddim(x_dims));
-    Tensor Y = Reshape<T>(y, make_ddim(y_dims));
-    Tensor dOut = Reshape<T>(dout, make_ddim(dout_dims));
-
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    if (dx) {
-      dx->mutable_data<T>(context.GetPlace());
-      const Tensor& dOut_for_dX =
-          (x_dims.size() == 2 && y_dims.size() == 3)
-              ? CombineBatchAndN<DeviceContext, T>(dev_ctx, dOut)
-              : dOut;
-      if (x_dims.size() == 2 && y_dims.size() == 3) {
-        Y = transpose_y ? CombineBatchAndM<T>(Y)
-                        : CombineBatchAndN<DeviceContext, T>(dev_ctx, Y);
-      }
-      if (transpose_x) {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, Y, transpose_y, dOut_for_dX, transpose_x, T(1), dx, T(0));
-      } else {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, dOut_for_dX, transpose_x, Y, !transpose_y, T(1), dx, T(0));
-      }
-    }
-
-    if (dy) {
-      dy->mutable_data<T>(context.GetPlace());
-      const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3)
-                                      ? CombineBatchAndM<T>(dOut)
-                                      : dOut;
-      if (y_dims.size() == 2 && x_dims.size() == 3) {
-        X = transpose_x ? CombineBatchAndN<DeviceContext, T>(dev_ctx, X)
-                        : CombineBatchAndM<T>(X);
-        dOut = CombineBatchAndM<T>(dOut);
-      }
-      if (transpose_y) {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, dOut_for_dY, transpose_y, X, transpose_x, T(1), dy, T(0));
-      } else {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, X, !transpose_x, dOut_for_dY, transpose_y, T(1), dy, T(0));
-      }
-    }
-  }
-};
-}  // namespace matmul_detail
-
-using matmul_detail::MatMulKernel;
-using matmul_detail::MatMulGradKernel;
-
-}  // namespace operators
-}  // namespace paddle
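The deleted `CombineBatchAndM`/`CombineBatchAndN` helpers live on as `FoldInitDims`/`FoldHeadAndLastDims` in the new matmul_op.cc. A sketch of the index mapping behind the second, data-moving reshape (plain arrays, illustration only):

```cpp
#include <cstdio>
#include <vector>

// FoldInitDims: P x M x N -> (P*M) x N is a pure metadata change; the
// row-major buffer is untouched. FoldHeadAndLastDims: P x M x N ->
// M x (P*N) first transposes to axis order {1, 0, 2}, then reshapes.
// This function reproduces that transpose's index mapping.
std::vector<float> FoldHeadAndLast(const std::vector<float>& in, int P, int M,
                                   int N) {
  std::vector<float> out(in.size());
  for (int p = 0; p < P; ++p)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        // Output is M x (P*N): row m, column p*N + n.
        out[(m * P + p) * N + n] = in[(p * M + m) * N + n];
  return out;
}

int main() {
  // A 2 x 2 x 2 tensor [[0,1],[2,3]], [[4,5],[6,7]] folds to the
  // 2 x 4 matrix [[0,1,4,5],[2,3,6,7]].
  auto out = FoldHeadAndLast({0, 1, 2, 3, 4, 5, 6, 7}, 2, 2, 2);
  for (float v : out) std::printf("%g ", v);
  std::printf("\n");
}
```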
diff --git a/python/paddle/fluid/tests/unittests/test_matmul_op.py b/python/paddle/fluid/tests/unittests/test_matmul_op.py
index 44ac4683891ffd3141a126740f4fddb47550e183..cae2c8fa87d9857de8f26cf4962d9370eca66243 100644
--- a/python/paddle/fluid/tests/unittests/test_matmul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matmul_op.py
@@ -111,21 +111,24 @@ class Generator(object):


 # Generate test cases for all possibilities
-for dim_X in [1, 2, 3]:
-    for dim_Y in [1, 2, 3]:
-        for transpose_X in [False, True]:
-            for transpose_Y in [False, True]:
-                test_name = (
-                    'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
-                        dim_X, dim_Y, transpose_X, transpose_Y))
-                shape_X, shape_Y = generate_compatible_shapes(
-                    dim_X, dim_Y, transpose_X, transpose_Y)
-                globals()[test_name] = type(test_name, (Generator, OpTest), {
-                    'shape_X': shape_X,
-                    'shape_Y': shape_Y,
-                    'transpose_X': transpose_X,
-                    'transpose_Y': transpose_Y,
-                })
+def inject_test(dim_x, dim_y, trans_x, trans_y):
+    test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
+        dim_x, dim_y, trans_x, trans_y))
+    shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x,
+                                                  trans_y)
+    globals()[test_name] = type(test_name, (Generator, OpTest), {
+        'shape_X': shape_x,
+        'shape_Y': shape_y,
+        'transpose_X': trans_x,
+        'transpose_Y': trans_y,
+    })
+
+
+for dim_X in (1, 2, 3):
+    for dim_Y in (1, 2, 3):
+        for transpose_x in (False, True):
+            for transpose_y in (False, True):
+                inject_test(dim_X, dim_Y, transpose_x, transpose_y)


 # Test case n-dim
@@ -149,7 +152,7 @@ def generate_compatible_shapes(dim, transpose_X, transpose_Y):
     return shape_X, shape_Y


-# Test case n-dim
+# # Test case n-dim
 for dim in [4]:
     for transpose_X in [False, True]:
         for transpose_Y in [False, True]: