diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 77a1e22b41e8dcd1fe78f3c4730653dee04db80e..aad1357598c629a4edfe0ad9b23f0241093a2522 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -130,6 +130,87 @@ void matmul( matrix_b.data(), beta, matrix_out->data()); } +#ifdef PADDLE_USE_MKLML +// Use cblas_{s,d}gemm_batched if available: Run with 1 group of size batchSize. +template <> +void batched_gemm( + const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const float alpha, const float* A, const float* B, const float beta, + float* C, const int batchCount, const int strideA, const int strideB) { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA]; + b_array[k] = &B[k * strideB]; + c_array[k] = &C[k * M * N]; + } + cblas_sgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, + a_array.data(), &lda, b_array.data(), &ldb, &beta, + c_array.data(), &ldc, 1 /* group_count */, &batchCount); +} + +template <> +void batched_gemm( + const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const double alpha, const double* A, const double* B, const double beta, + double* C, const int batchCount, const int strideA, const int strideB) { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA]; + b_array[k] = &B[k * strideB]; + c_array[k] = &C[k * M * N]; + } + cblas_dgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, + a_array.data(), &lda, b_array.data(), &ldb, &beta, + c_array.data(), &ldc, 1 /* group_count */, &batchCount); +} +#else +// The below is a naive but correct serial implementation that just loops +// over the batch dimension. This is a fallback for when the batched gemm +// functions of Intel MKL are not available. In the future, this computation +// should be parallelized. +template <> +void batched_gemm( + const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const float alpha, const float* A, const float* B, const float beta, + float* C, const int batchCount, const int strideA, const int strideB) { + for (int k = 0; k < batchCount; ++k) { + const float* Ak = &A[k * strideA]; + const float* Bk = &B[k * strideB]; + float* Ck = &C[k * M * N]; + gemm(context, transA, transB, M, N, K, alpha, Ak, + Bk, beta, Ck); + } +} + +template <> +void batched_gemm( + const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const double alpha, const double* A, const double* B, const double beta, + double* C, const int batchCount, const int strideA, const int strideB) { + for (int k = 0; k < batchCount; ++k) { + const double* Ak = &A[k * strideA]; + const double* Bk = &B[k * strideB]; + double* Ck = &C[k * M * N]; + gemm(context, transA, transB, M, N, K, alpha, + Ak, Bk, beta, Ck); + } +} +#endif + template struct SetConstant; } // namespace math diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 7fbc03acf22231a6fa386aa67e43f738eadb18d3..5583683c6e12b88ba81015aef9161913de261ef2 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -155,6 +155,54 @@ void matmul( matrix_b.data(), beta, matrix_out->data()); } +template <> +void batched_gemm( + const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const float alpha, const float* A, const float* B, const float beta, + float* C, const int batchCount, const int strideA, const int strideB) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int strideC = M * N; + + PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched( + reinterpret_cast(context) + .cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, strideB, A, lda, strideA, + &beta, C, ldc, strideC, batchCount)); +} + +template <> +void batched_gemm( + const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const double alpha, const double* A, const double* B, const double beta, + double* C, const int batchCount, const int strideA, const int strideB) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int strideC = M * N; + + PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched( + reinterpret_cast(context) + .cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, strideB, A, lda, strideA, + &beta, C, ldc, strideC, batchCount)); +} + template struct SetConstant; } // namespace math diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 6f92d83aabbc77f7ea7d4159869e07126b270740..9777ebfd156709a370be2cb4ba0077ac7c6735fb 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -63,7 +63,7 @@ namespace math { // Support continuous memory now // If transA = N, and transB = N -// Then matrixA: M * K, matrixB: K * N matrixC : M * N +// Then matrixA: M * K, matrixB: K * N, matrixC : M * N // For more detailed info, please refer to // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html template @@ -85,6 +85,14 @@ void matmul(const platform::DeviceContext& context, const framework::Tensor& matrix_b, bool trans_b, T alpha, framework::Tensor* matrix_out, T beta); +// Batched gemm +template +void batched_gemm(const platform::DeviceContext& context, + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, + const int M, const int N, const int K, const T alpha, + const T* A, const T* B, const T beta, T* C, + const int batchCount, const int strideA, const int strideB); + template struct SetConstant { void operator()(const platform::DeviceContext& context, diff --git a/paddle/operators/math/matmul.h b/paddle/operators/math/matmul.h new file mode 100644 index 0000000000000000000000000000000000000000..6ba9a0ba9a70bd938f9362179990ab68fa3186ba --- /dev/null +++ b/paddle/operators/math/matmul.h @@ -0,0 +1,124 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +// Implements the logic of numpy matmul: +// https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html +// +// but allowing also for a, b to be transposed +// +// Both a & b can be 1- to 3-dimensional. Higher rank tensors are not supported +// yet. +template +class MatMulFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& a, bool trans_a, + const framework::Tensor& b, bool trans_b, T alpha, + framework::Tensor* out, T beta) { + auto dim_a = a.dims(); + auto dim_b = b.dims(); + + PADDLE_ENFORCE(a.place() == b.place() && b.place() == out->place(), + "Tensors must all be in the same place."); + PADDLE_ENFORCE_GE(dim_a.size(), 1, + "Input tensor a must be at least 1-dimensional."); + PADDLE_ENFORCE_GE(dim_b.size(), 1, + "Input tensor b must be at least 1-dimensional."); + PADDLE_ENFORCE_LE(dim_a.size(), 3, + "Input tensor a must be at most 3-dimensional."); + PADDLE_ENFORCE_LE(dim_b.size(), 3, + "Input tensor b must be at most 3-dimensional."); + + int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0, + strideA = 0, strideB = 0; + + switch (dim_a.size()) { + case 1: + // similar to np.matmul: + // prepend dimension 1 (no transpose) or append dimension 1 (transpose) + M = trans_a ? dim_a[0] : 1; + kA = trans_a ? 1 : dim_a[0]; + break; + case 2: + M = trans_a ? dim_a[1] : dim_a[0]; + kA = trans_a ? dim_a[0] : dim_a[1]; + break; + case 3: + batchCountA = dim_a[0]; + M = trans_a ? dim_a[2] : dim_a[1]; + kA = trans_a ? dim_a[1] : dim_a[2]; + strideA = M * kA; + break; + default: + assert(false); + } + + switch (dim_b.size()) { + case 1: + // similar to np.matmul: + // append dimension 1 (no transpose) or prepend dimension 1 (transpose) + kB = trans_b ? 1 : dim_b[0]; + N = trans_b ? dim_b[0] : 1; + break; + case 2: + kB = trans_b ? dim_b[1] : dim_b[0]; + N = trans_b ? dim_b[0] : dim_b[1]; + break; + case 3: + batchCountB = dim_b[0]; + kB = trans_b ? dim_b[2] : dim_b[1]; + N = trans_b ? dim_b[1] : dim_b[2]; + strideB = kB * N; + break; + default: + assert(false); + } + + PADDLE_ENFORCE_EQ( + kA, kB, + "First matrix's width must be equal with second matrix's height."); + if (batchCountA && batchCountB) { + PADDLE_ENFORCE_EQ( + batchCountA, batchCountB, + "When input tensors a and b are both batched, they must have the " + "same batch dimension."); + } + int batchCount = std::max(batchCountA, batchCountB); + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + if (!batchCount) { + // regular matrix multiplication + gemm(context, transA, transB, M, N, kA, alpha, a.data(), + b.data(), beta, out->data()); + } else { + // batched matrix multiplication + batched_gemm(context, transA, transB, M, N, kA, alpha, + a.data(), b.data(), beta, out->data(), + batchCount, strideA, strideB); + } + } +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/matmul_op.cc b/paddle/operators/matmul_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ecbee3b413617e3a5523d9a32e72bc08bd316c5 --- /dev/null +++ b/paddle/operators/matmul_op.cc @@ -0,0 +1,208 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/matmul_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class MatMulOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* context) const override { + PADDLE_ENFORCE(context->HasInput("X"), + "Input(X) of MatMulOp should not be null."); + PADDLE_ENFORCE(context->HasInput("Y"), + "Input(Y) of MatMulOp should not be null."); + PADDLE_ENFORCE(context->HasOutput("Out"), + "Output(Out) of MatMulOp should not be null."); + + auto dim_x = context->GetInputDim("X"); + auto dim_y = context->GetInputDim("Y"); + bool transpose_x = context->Attrs().Get("transpose_X"); + bool transpose_y = context->Attrs().Get("transpose_Y"); + + PADDLE_ENFORCE_GE(dim_x.size(), 1, + "Input tensor X must be at least 1-dimensional."); + PADDLE_ENFORCE_GE(dim_y.size(), 1, + "Input tensor Y must be at least 1-dimensional."); + PADDLE_ENFORCE_LE(dim_x.size(), 3, + "Input tensor X must be at most 3-dimensional."); + PADDLE_ENFORCE_LE(dim_y.size(), 3, + "Input tensor Y must be at most 3-dimensional."); + + int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0; + bool remove_initial_dim = false, remove_final_dim = false; + + switch (dim_x.size()) { + case 1: + if (transpose_x) { + M = dim_x[0]; + KX = 1; + } else { + M = 1; + KX = dim_x[0]; + remove_initial_dim = true; + } + break; + case 2: + M = transpose_x ? dim_x[1] : dim_x[0]; + KX = transpose_x ? dim_x[0] : dim_x[1]; + break; + case 3: + batchCountX = dim_x[0]; + M = transpose_x ? dim_x[2] : dim_x[1]; + KX = transpose_x ? dim_x[1] : dim_x[2]; + break; + default: + assert(false); + } + + switch (dim_y.size()) { + case 1: + if (transpose_y) { + N = dim_y[0]; + KY = 1; + } else { + N = 1; + KY = dim_y[0]; + remove_final_dim = true; + } + break; + case 2: + KY = transpose_y ? dim_y[1] : dim_y[0]; + N = transpose_y ? dim_y[0] : dim_y[1]; + break; + case 3: + batchCountY = dim_y[0]; + KY = transpose_y ? dim_y[2] : dim_y[1]; + N = transpose_y ? dim_y[1] : dim_y[2]; + break; + default: + assert(false); + } + + PADDLE_ENFORCE_EQ( + KX, KY, + "First matrix's width must be equal with second matrix's height."); + if (batchCountX && batchCountY) { + PADDLE_ENFORCE_EQ( + batchCountX, batchCountY, + "When Input(X) and Input(Y) are both three dimensional, they " + "must have the same batch dimension."); + } + int batchCount = std::max(batchCountX, batchCountY); + + std::vector dim_out; + if (batchCount) { + dim_out.push_back(batchCount); + } + if (!remove_initial_dim) { + dim_out.push_back(M); + } + if (!remove_final_dim) { + dim_out.push_back(N); + } + if (dim_out.size() == 0) { + // We don't support 0-dimensional Tensors (scalars), so instead + // treat the output as a Tensor of shape (1, ) in this case. + dim_out.push_back(1); + } + context->SetOutputDim("Out", framework::make_ddim(dim_out)); + context->ShareLoD("X", /*->*/ "Out"); + } +}; + +class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MatMulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The first input of MatMul op"); + AddInput("Y", "The second input of MatMul op"); + AddOutput("Out", "The output of MatMul op"); + AddAttr("transpose_X", + R"DOC(If true, use the transpose of `X`. + )DOC") + .SetDefault(false); + AddAttr("transpose_Y", + R"DOC(If true, use the transpose of `Y`. + )DOC") + .SetDefault(false); + AddComment(R"DOC( +The MatMul operator is used to perform (batched) matrix multiplication +over the last two dimensions of the input tensors `X` and `Y`. + +If a transpose flag is specified, the last two dimensions of the +tensor are transposed. If the tensor is rank-1 of shape [D], then +for `X` it is treated as [1, D] in nontransposed form and as [D, 1] +in transposed form, whereas for `Y` it is the opposite: It is treated +as [D, 1] in nontransposed form and as [1, D] in transposed form. + +Examples without transpose: +- X: [K], Y: [K] => Out: [1] +- X: [K], Y: [K, N] => Out: [N] +- X: [B, M, K], Y: [K] => Out: [B, M] +- X: [M, K], Y: [B, K, N] => Out: [B, M, N] +- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N] + +The behavior is designed to be similar to the `numpy.matmul` function. +The differences are: +- Currently only rank 1 to rank 3 input tensors are supported. +- We add `transpose_X` and `transpose_Y` flags. + +Both the input `X` and `Y` can carry the LoD (Level of Details) information, +or not. But the output only shares the LoD with input `X`. +)DOC"); + } +}; + +class MatMulOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* context) const override { + PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = context->GetInputDim("X"); + auto y_dims = context->GetInputDim("Y"); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (context->HasOutput(x_grad_name)) { + context->SetOutputDim(x_grad_name, x_dims); + } + if (context->HasOutput(y_grad_name)) { + context->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(matmul, ops::MatMulOp, ops::MatMulOpMaker, matmul_grad, + ops::MatMulOpGrad); +REGISTER_OP_CPU_KERNEL(matmul, + ops::MatMulKernel); +REGISTER_OP_CPU_KERNEL( + matmul_grad, ops::MatMulGradKernel); diff --git a/paddle/operators/matmul_op.cu b/paddle/operators/matmul_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..b7e66382f00445b087e14103e7a148d450b37405 --- /dev/null +++ b/paddle/operators/matmul_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/matmul_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(matmul, + ops::MatMulKernel); +REGISTER_OP_GPU_KERNEL( + matmul_grad, ops::MatMulGradKernel); diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8ae54e1eec33c4bce563f697bafbdc68f97ab746 --- /dev/null +++ b/paddle/operators/matmul_op.h @@ -0,0 +1,228 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/matmul.h" +#include "paddle/operators/transpose_op.h" + +namespace paddle { +namespace operators { +namespace matmul_detail { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; +using framework::make_ddim; +using framework::vectorize; + +template +class MatMulKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor& x = *context.Input("X"); + const Tensor& y = *context.Input("Y"); + Tensor* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + bool transpose_x = context.Attr("transpose_X"); + bool transpose_y = context.Attr("transpose_Y"); + + math::MatMulFunctor()(context.device_context(), x, transpose_x, y, + transpose_y, T(1), out, T(0)); + } +}; + +template +inline Tensor Reshape(const Tensor& input, const DDim& dims) { + Tensor output; + output.ShareDataWith(input); + output.Resize(dims); + return output; +} + +// Reshape a rank-3 tensor from P x M x N to (P * M) x N. +// Identity op if the tensor is not of rank 3. +template +Tensor CombineBatchAndM(const Tensor& input) { + Tensor output; + output.ShareDataWith(input); + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + std::vector out_dims = {in_dims[0] * in_dims[1], in_dims[2]}; + output.Resize(make_ddim(out_dims)); + } + return output; +} + +// Reshape a rank-3 tensor from P x M x N to M x (P * N). +// (Warning: This requires transposing data and writes into new memory.) +// Identity op if the tensor is not of rank 3. +template +Tensor CombineBatchAndN(const framework::ExecutionContext& context, + const Tensor& input) { + Tensor output; + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + output.Resize(in_dims); + output.mutable_data(context.GetPlace()); + EigenTranspose(context, input, output, {1, 0, 2}); + std::vector out_dims = {in_dims[1], in_dims[0] * in_dims[2]}; + output.Resize(make_ddim(out_dims)); + } else { + output.ShareDataWith(input); + } + return output; +} + +// Using dimensional constraints on matrix multiplication, it is +// straight-forward to check the following table for when X and Y +// are both matrices. +// +// transpose_X | False | True | False | True +// transpose_Y | False | False | True | True +// -----------+----------+----------+----------+----------- +// dX = | dOut Y^T | Y dOut^T | dOut Y | Y^T dOut^T +// dY = | X^T dOut | X dOut | dOut^T X | dOut^T X^T +// +// When X is a vector of size K, we treat it instead as a matrix of shape +// (1, K). Similarly, when Y is a vector of size K, we treat it instead as +// a matrix of shape (K, 1). +// +// When X and Y are both 3-dimensional tensors, then the first dimension +// the batch dimension can be ignored and the exact same formulas apply +// as for two matrices. +// +// Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end +// up with formulas like +// +// dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj} +// +// To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N +// to X: (P * M) x K, dOut: (P * M) x N. +template +class MatMulGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor& x = *context.Input("X"); + const Tensor& y = *context.Input("Y"); + const Tensor& dout = *context.Input(framework::GradVarName("Out")); + Tensor* dx = context.Output(framework::GradVarName("X")); + Tensor* dy = context.Output(framework::GradVarName("Y")); + bool transpose_x = context.Attr("transpose_X"); + bool transpose_y = context.Attr("transpose_Y"); + + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + + // If X is a vector, reshape it to a matrix. + if (x_dims.size() == 1) { + x_dims.insert(x_dims.begin(), 1); + } + + // If Y is a vector, reshape it to a matrix. + if (y_dims.size() == 1) { + y_dims.push_back(1); + } + + // Fix the dOut dimensions. + int M = 0, N = 0, batchCountX = 0, batchCountY = 0; + + switch (x_dims.size()) { + case 2: + M = transpose_x ? x_dims[1] : x_dims[0]; + break; + case 3: + batchCountX = x_dims[0]; + M = transpose_x ? x_dims[2] : x_dims[1]; + break; + default: + assert(false); + } + + switch (y_dims.size()) { + case 2: + N = transpose_y ? y_dims[0] : y_dims[1]; + break; + case 3: + batchCountY = y_dims[0]; + N = transpose_y ? y_dims[1] : y_dims[2]; + break; + default: + assert(false); + } + if (batchCountX && batchCountY) { + PADDLE_ENFORCE_EQ( + batchCountX, batchCountY, + "When Input(X) and Input(Y) are both three dimensional, they " + "must have the same batch dimension."); + } + int batchCount = std::max(batchCountX, batchCountY); + std::vector dout_dims = {M, N}; + if (batchCount) { + dout_dims.insert(dout_dims.begin(), batchCount); + } + Tensor X = Reshape(x, make_ddim(x_dims)); + Tensor Y = Reshape(y, make_ddim(y_dims)); + Tensor dOut = Reshape(dout, make_ddim(dout_dims)); + + if (dx) { + dx->mutable_data(context.GetPlace()); + const Tensor& dOut_for_dX = + (x_dims.size() == 2 && y_dims.size() == 3) + ? CombineBatchAndN(context, dOut) + : dOut; + if (x_dims.size() == 2 && y_dims.size() == 3) { + Y = transpose_y ? CombineBatchAndM(Y) + : CombineBatchAndN(context, Y); + } + if (transpose_x) { + math::MatMulFunctor()(context.device_context(), Y, + transpose_y, dOut_for_dX, transpose_x, + T(1), dx, T(0)); + } else { + math::MatMulFunctor()(context.device_context(), dOut_for_dX, + transpose_x, Y, !transpose_y, T(1), dx, + T(0)); + } + } + + if (dy) { + dy->mutable_data(context.GetPlace()); + const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3) + ? CombineBatchAndM(dOut) + : dOut; + if (y_dims.size() == 2 && x_dims.size() == 3) { + X = transpose_x ? CombineBatchAndN(context, X) + : CombineBatchAndM(X); + dOut = CombineBatchAndM(dOut); + } + if (transpose_y) { + math::MatMulFunctor()(context.device_context(), dOut_for_dY, + transpose_y, X, transpose_x, T(1), dy, + T(0)); + } else { + math::MatMulFunctor()(context.device_context(), X, + !transpose_x, dOut_for_dY, transpose_y, + T(1), dy, T(0)); + } + } + } +}; +} // namespace matmul_detail + +using matmul_detail::MatMulKernel; +using matmul_detail::MatMulGradKernel; + +} // namespace operators +} // namespace paddle diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index 9d8343c0b5e200b390ccda760f09816959952e9d..6b64539b0a9a4d535a53447fbcc0e458f3ac9129 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -77,6 +77,10 @@ extern void *cublas_dso_handle; __macro(cublasDgemmBatched); \ __macro(cublasCgemmBatched); \ __macro(cublasZgemmBatched); \ + __macro(cublasSgemmStridedBatched); \ + __macro(cublasDgemmStridedBatched); \ + __macro(cublasCgemmStridedBatched); \ + __macro(cublasZgemmStridedBatched); \ __macro(cublasSgetrfBatched); \ __macro(cublasSgetriBatched); \ __macro(cublasDgetrfBatched); \ diff --git a/python/paddle/v2/framework/tests/test_matmul_op.py b/python/paddle/v2/framework/tests/test_matmul_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d51572c8ab7c44fa0c6e83e50b56f05780530c61 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_matmul_op.py @@ -0,0 +1,119 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y): + BATCH_SIZE = 2 + M = 3 + N = 4 + K = 5 + if (dim_X == 1 and transpose_X) or (dim_Y == 1 and transpose_Y): + K = 1 + if dim_X == 1: + if transpose_X: + shape_X = [M] + else: + shape_X = [K] + if dim_Y == 1: + if transpose_Y: + shape_Y = [N] + else: + shape_Y = [K] + if dim_X >= 2: + if transpose_X: + shape_X = [K, M] + else: + shape_X = [M, K] + if dim_X == 3: + shape_X = [BATCH_SIZE] + shape_X + if dim_Y >= 2: + if transpose_Y: + shape_Y = [N, K] + else: + shape_Y = [K, N] + if dim_Y == 3: + shape_Y = [BATCH_SIZE] + shape_Y + return shape_X, shape_Y + + +def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): + """Reference forward implementation using np.matmul.""" + # np.matmul does not support the transpose flags, so we manually + # transpose X and Y appropriately. + if transpose_X: + if X.ndim == 1: + X = X.reshape((X.size, 1)) + elif X.ndim == 2: + X = X.T + elif X.ndim == 3: + X = np.transpose(X, (0, 2, 1)) + else: + raise ValueError('X must have between 1 and 3 dimensions') + if transpose_Y: + if Y.ndim == 1: + Y = Y.reshape((1, Y.size)) + elif Y.ndim == 2: + Y = Y.T + elif Y.ndim == 3: + Y = np.transpose(Y, (0, 2, 1)) + else: + raise ValueError('Y must have between 1 and 3 dimensions') + Out = np.matmul(X, Y) + if not Out.shape: + # We do not support 0-dimensional Tensors (scalars). So where + # np.matmul outputs a scalar, we must convert to a Tensor of + # shape (1, ) instead. + # Everywhere else, we are compatible with np.matmul. + Out = np.array([Out], dtype="float32") + return Out + + +class Generator(object): + def setUp(self): + self.op_type = "matmul" + X = np.random.random(self.shape_X).astype("float32") + Y = np.random.random(self.shape_Y).astype("float32") + Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y) + self.inputs = {'X': X, 'Y': Y} + self.attrs = { + 'transpose_X': self.transpose_X, + 'transpose_Y': self.transpose_Y + } + self.outputs = {'Out': Out} + + def test_check_output(self): + self.check_output(atol=1e-2) + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5) + + def test_check_grad_ignore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X")) + + def test_check_grad_ignore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y')) + + +# Generate test cases for all possibilities +for dim_X in [1, 2, 3]: + for dim_Y in [1, 2, 3]: + for transpose_X in [False, True]: + for transpose_Y in [False, True]: + test_name = ( + 'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format( + dim_X, dim_Y, transpose_X, transpose_Y)) + shape_X, shape_Y = generate_compatible_shapes( + dim_X, dim_Y, transpose_X, transpose_Y) + test_class = type(test_name, (Generator, OpTest), { + 'shape_X': shape_X, + 'shape_Y': shape_Y, + 'transpose_X': transpose_X, + 'transpose_Y': transpose_Y, + }) + globals()[test_name] = test_class + +if __name__ == "__main__": + unittest.main()