提交 16489827 编写于 作者: M Markus Kliegl 提交者: GitHub

MatMul operator (#4856)

* initial matmul operator

Similar to np.matmul, but also has transpose_X and transpose_Y flags,
and only supports tensors from rank 1 to 3 inclusive.

For GPU, uses cublas?gemmStridedBatched. For CPU, uses
cblas_?gemm_batch if available via MKL; otherwise a simple serial
implementation that loops over the batch dimension is employed for now.
上级 fd96914d
...@@ -130,6 +130,87 @@ void matmul<platform::CPUPlace, double>( ...@@ -130,6 +130,87 @@ void matmul<platform::CPUPlace, double>(
matrix_b.data<double>(), beta, matrix_out->data<double>()); matrix_b.data<double>(), beta, matrix_out->data<double>());
} }
#ifdef PADDLE_USE_MKLML
// Use cblas_{s,d}gemm_batched if available: Run with 1 group of size batchSize.
template <>
void batched_gemm<platform::CPUPlace, float>(
const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int strideA, const int strideB) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
auto a_array = std::vector<const float*>(batchCount);
auto b_array = std::vector<const float*>(batchCount);
auto c_array = std::vector<float*>(batchCount);
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA];
b_array[k] = &B[k * strideB];
c_array[k] = &C[k * M * N];
}
cblas_sgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
a_array.data(), &lda, b_array.data(), &ldb, &beta,
c_array.data(), &ldc, 1 /* group_count */, &batchCount);
}
template <>
void batched_gemm<platform::CPUPlace, double>(
const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int strideA, const int strideB) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
auto a_array = std::vector<const double*>(batchCount);
auto b_array = std::vector<const double*>(batchCount);
auto c_array = std::vector<double*>(batchCount);
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA];
b_array[k] = &B[k * strideB];
c_array[k] = &C[k * M * N];
}
cblas_dgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
a_array.data(), &lda, b_array.data(), &ldb, &beta,
c_array.data(), &ldc, 1 /* group_count */, &batchCount);
}
#else
// The below is a naive but correct serial implementation that just loops
// over the batch dimension. This is a fallback for when the batched gemm
// functions of Intel MKL are not available. In the future, this computation
// should be parallelized.
template <>
void batched_gemm<platform::CPUPlace, float>(
const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int strideA, const int strideB) {
for (int k = 0; k < batchCount; ++k) {
const float* Ak = &A[k * strideA];
const float* Bk = &B[k * strideB];
float* Ck = &C[k * M * N];
gemm<platform::CPUPlace, float>(context, transA, transB, M, N, K, alpha, Ak,
Bk, beta, Ck);
}
}
template <>
void batched_gemm<platform::CPUPlace, double>(
const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int strideA, const int strideB) {
for (int k = 0; k < batchCount; ++k) {
const double* Ak = &A[k * strideA];
const double* Bk = &B[k * strideB];
double* Ck = &C[k * M * N];
gemm<platform::CPUPlace, double>(context, transA, transB, M, N, K, alpha,
Ak, Bk, beta, Ck);
}
}
#endif
template struct SetConstant<platform::CPUPlace, float>; template struct SetConstant<platform::CPUPlace, float>;
} // namespace math } // namespace math
......
...@@ -155,6 +155,54 @@ void matmul<platform::GPUPlace, double>( ...@@ -155,6 +155,54 @@ void matmul<platform::GPUPlace, double>(
matrix_b.data<double>(), beta, matrix_out->data<double>()); matrix_b.data<double>(), beta, matrix_out->data<double>());
} }
template <>
void batched_gemm<platform::GPUPlace, float>(
const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int strideA, const int strideB) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, strideB, A, lda, strideA,
&beta, C, ldc, strideC, batchCount));
}
template <>
void batched_gemm<platform::GPUPlace, double>(
const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int strideA, const int strideB) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, strideB, A, lda, strideA,
&beta, C, ldc, strideC, batchCount));
}
template struct SetConstant<platform::GPUPlace, float>; template struct SetConstant<platform::GPUPlace, float>;
} // namespace math } // namespace math
......
...@@ -63,7 +63,7 @@ namespace math { ...@@ -63,7 +63,7 @@ namespace math {
// Support continuous memory now // Support continuous memory now
// If transA = N, and transB = N // If transA = N, and transB = N
// Then matrixA: M * K, matrixB: K * N matrixC : M * N // Then matrixA: M * K, matrixB: K * N, matrixC : M * N
// For more detailed info, please refer to // For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template <typename Place, typename T> template <typename Place, typename T>
...@@ -85,6 +85,14 @@ void matmul(const platform::DeviceContext& context, ...@@ -85,6 +85,14 @@ void matmul(const platform::DeviceContext& context,
const framework::Tensor& matrix_b, bool trans_b, T alpha, const framework::Tensor& matrix_b, bool trans_b, T alpha,
framework::Tensor* matrix_out, T beta); framework::Tensor* matrix_out, T beta);
// Batched gemm
template <typename Place, typename T>
void batched_gemm(const platform::DeviceContext& context,
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha,
const T* A, const T* B, const T beta, T* C,
const int batchCount, const int strideA, const int strideB);
template <typename Place, typename T> template <typename Place, typename T>
struct SetConstant { struct SetConstant {
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
// Implements the logic of numpy matmul:
// https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
//
// but allowing also for a, b to be transposed
//
// Both a & b can be 1- to 3-dimensional. Higher rank tensors are not supported
// yet.
template <typename Place, typename T>
class MatMulFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b, T alpha,
framework::Tensor* out, T beta) {
auto dim_a = a.dims();
auto dim_b = b.dims();
PADDLE_ENFORCE(a.place() == b.place() && b.place() == out->place(),
"Tensors must all be in the same place.");
PADDLE_ENFORCE_GE(dim_a.size(), 1,
"Input tensor a must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_b.size(), 1,
"Input tensor b must be at least 1-dimensional.");
PADDLE_ENFORCE_LE(dim_a.size(), 3,
"Input tensor a must be at most 3-dimensional.");
PADDLE_ENFORCE_LE(dim_b.size(), 3,
"Input tensor b must be at most 3-dimensional.");
int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0,
strideA = 0, strideB = 0;
switch (dim_a.size()) {
case 1:
// similar to np.matmul:
// prepend dimension 1 (no transpose) or append dimension 1 (transpose)
M = trans_a ? dim_a[0] : 1;
kA = trans_a ? 1 : dim_a[0];
break;
case 2:
M = trans_a ? dim_a[1] : dim_a[0];
kA = trans_a ? dim_a[0] : dim_a[1];
break;
case 3:
batchCountA = dim_a[0];
M = trans_a ? dim_a[2] : dim_a[1];
kA = trans_a ? dim_a[1] : dim_a[2];
strideA = M * kA;
break;
default:
assert(false);
}
switch (dim_b.size()) {
case 1:
// similar to np.matmul:
// append dimension 1 (no transpose) or prepend dimension 1 (transpose)
kB = trans_b ? 1 : dim_b[0];
N = trans_b ? dim_b[0] : 1;
break;
case 2:
kB = trans_b ? dim_b[1] : dim_b[0];
N = trans_b ? dim_b[0] : dim_b[1];
break;
case 3:
batchCountB = dim_b[0];
kB = trans_b ? dim_b[2] : dim_b[1];
N = trans_b ? dim_b[1] : dim_b[2];
strideB = kB * N;
break;
default:
assert(false);
}
PADDLE_ENFORCE_EQ(
kA, kB,
"First matrix's width must be equal with second matrix's height.");
if (batchCountA && batchCountB) {
PADDLE_ENFORCE_EQ(
batchCountA, batchCountB,
"When input tensors a and b are both batched, they must have the "
"same batch dimension.");
}
int batchCount = std::max(batchCountA, batchCountB);
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
if (!batchCount) {
// regular matrix multiplication
gemm<Place, T>(context, transA, transB, M, N, kA, alpha, a.data<T>(),
b.data<T>(), beta, out->data<T>());
} else {
// batched matrix multiplication
batched_gemm<Place, T>(context, transA, transB, M, N, kA, alpha,
a.data<T>(), b.data<T>(), beta, out->data<T>(),
batchCount, strideA, strideB);
}
}
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/matmul_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class MatMulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE(context->HasInput("X"),
"Input(X) of MatMulOp should not be null.");
PADDLE_ENFORCE(context->HasInput("Y"),
"Input(Y) of MatMulOp should not be null.");
PADDLE_ENFORCE(context->HasOutput("Out"),
"Output(Out) of MatMulOp should not be null.");
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
bool transpose_x = context->Attrs().Get<bool>("transpose_X");
bool transpose_y = context->Attrs().Get<bool>("transpose_Y");
PADDLE_ENFORCE_GE(dim_x.size(), 1,
"Input tensor X must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_y.size(), 1,
"Input tensor Y must be at least 1-dimensional.");
PADDLE_ENFORCE_LE(dim_x.size(), 3,
"Input tensor X must be at most 3-dimensional.");
PADDLE_ENFORCE_LE(dim_y.size(), 3,
"Input tensor Y must be at most 3-dimensional.");
int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
bool remove_initial_dim = false, remove_final_dim = false;
switch (dim_x.size()) {
case 1:
if (transpose_x) {
M = dim_x[0];
KX = 1;
} else {
M = 1;
KX = dim_x[0];
remove_initial_dim = true;
}
break;
case 2:
M = transpose_x ? dim_x[1] : dim_x[0];
KX = transpose_x ? dim_x[0] : dim_x[1];
break;
case 3:
batchCountX = dim_x[0];
M = transpose_x ? dim_x[2] : dim_x[1];
KX = transpose_x ? dim_x[1] : dim_x[2];
break;
default:
assert(false);
}
switch (dim_y.size()) {
case 1:
if (transpose_y) {
N = dim_y[0];
KY = 1;
} else {
N = 1;
KY = dim_y[0];
remove_final_dim = true;
}
break;
case 2:
KY = transpose_y ? dim_y[1] : dim_y[0];
N = transpose_y ? dim_y[0] : dim_y[1];
break;
case 3:
batchCountY = dim_y[0];
KY = transpose_y ? dim_y[2] : dim_y[1];
N = transpose_y ? dim_y[1] : dim_y[2];
break;
default:
assert(false);
}
PADDLE_ENFORCE_EQ(
KX, KY,
"First matrix's width must be equal with second matrix's height.");
if (batchCountX && batchCountY) {
PADDLE_ENFORCE_EQ(
batchCountX, batchCountY,
"When Input(X) and Input(Y) are both three dimensional, they "
"must have the same batch dimension.");
}
int batchCount = std::max(batchCountX, batchCountY);
std::vector<int64_t> dim_out;
if (batchCount) {
dim_out.push_back(batchCount);
}
if (!remove_initial_dim) {
dim_out.push_back(M);
}
if (!remove_final_dim) {
dim_out.push_back(N);
}
if (dim_out.size() == 0) {
// We don't support 0-dimensional Tensors (scalars), so instead
// treat the output as a Tensor of shape (1, ) in this case.
dim_out.push_back(1);
}
context->SetOutputDim("Out", framework::make_ddim(dim_out));
context->ShareLoD("X", /*->*/ "Out");
}
};
class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MatMulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of MatMul op");
AddInput("Y", "The second input of MatMul op");
AddOutput("Out", "The output of MatMul op");
AddAttr<bool>("transpose_X",
R"DOC(If true, use the transpose of `X`.
)DOC")
.SetDefault(false);
AddAttr<bool>("transpose_Y",
R"DOC(If true, use the transpose of `Y`.
)DOC")
.SetDefault(false);
AddComment(R"DOC(
The MatMul operator is used to perform (batched) matrix multiplication
over the last two dimensions of the input tensors `X` and `Y`.
If a transpose flag is specified, the last two dimensions of the
tensor are transposed. If the tensor is rank-1 of shape [D], then
for `X` it is treated as [1, D] in nontransposed form and as [D, 1]
in transposed form, whereas for `Y` it is the opposite: It is treated
as [D, 1] in nontransposed form and as [1, D] in transposed form.
Examples without transpose:
- X: [K], Y: [K] => Out: [1]
- X: [K], Y: [K, N] => Out: [N]
- X: [B, M, K], Y: [K] => Out: [B, M]
- X: [M, K], Y: [B, K, N] => Out: [B, M, N]
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
The behavior is designed to be similar to the `numpy.matmul` function.
The differences are:
- Currently only rank 1 to rank 3 input tensors are supported.
- We add `transpose_X` and `transpose_Y` flags.
Both the input `X` and `Y` can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD with input `X`.
)DOC");
}
};
class MatMulOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = context->GetInputDim("X");
auto y_dims = context->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (context->HasOutput(x_grad_name)) {
context->SetOutputDim(x_grad_name, x_dims);
}
if (context->HasOutput(y_grad_name)) {
context->SetOutputDim(y_grad_name, y_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(matmul, ops::MatMulOp, ops::MatMulOpMaker, matmul_grad,
ops::MatMulOpGrad);
REGISTER_OP_CPU_KERNEL(matmul,
ops::MatMulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
matmul_grad, ops::MatMulGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/matmul_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(matmul,
ops::MatMulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
matmul_grad, ops::MatMulGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/matmul.h"
#include "paddle/operators/transpose_op.h"
namespace paddle {
namespace operators {
namespace matmul_detail {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
using framework::make_ddim;
using framework::vectorize;
template <typename Place, typename T>
class MatMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
Tensor* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y");
math::MatMulFunctor<Place, T>()(context.device_context(), x, transpose_x, y,
transpose_y, T(1), out, T(0));
}
};
template <typename T>
inline Tensor Reshape(const Tensor& input, const DDim& dims) {
Tensor output;
output.ShareDataWith<T>(input);
output.Resize(dims);
return output;
}
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
template <typename T>
Tensor CombineBatchAndM(const Tensor& input) {
Tensor output;
output.ShareDataWith<T>(input);
auto in_dims = input.dims();
if (in_dims.size() == 3) {
std::vector<int64_t> out_dims = {in_dims[0] * in_dims[1], in_dims[2]};
output.Resize(make_ddim(out_dims));
}
return output;
}
// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename Place, typename T>
Tensor CombineBatchAndN(const framework::ExecutionContext& context,
const Tensor& input) {
Tensor output;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
output.Resize(in_dims);
output.mutable_data<T>(context.GetPlace());
EigenTranspose<Place, T, 3>(context, input, output, {1, 0, 2});
std::vector<int64_t> out_dims = {in_dims[1], in_dims[0] * in_dims[2]};
output.Resize(make_ddim(out_dims));
} else {
output.ShareDataWith<T>(input);
}
return output;
}
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
//
// transpose_X | False | True | False | True
// transpose_Y | False | False | True | True
// -----------+----------+----------+----------+-----------
// dX = | dOut Y^T | Y dOut^T | dOut Y | Y^T dOut^T
// dY = | X^T dOut | X dOut | dOut^T X | dOut^T X^T
//
// When X is a vector of size K, we treat it instead as a matrix of shape
// (1, K). Similarly, when Y is a vector of size K, we treat it instead as
// a matrix of shape (K, 1).
//
// When X and Y are both 3-dimensional tensors, then the first dimension
// the batch dimension can be ignored and the exact same formulas apply
// as for two matrices.
//
// Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end
// up with formulas like
//
// dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj}
//
// To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N
// to X: (P * M) x K, dOut: (P * M) x N.
template <typename Place, typename T>
class MatMulGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
const Tensor& dout = *context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* dx = context.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = context.Output<Tensor>(framework::GradVarName("Y"));
bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y");
std::vector<int64_t> x_dims = vectorize(x.dims());
std::vector<int64_t> y_dims = vectorize(y.dims());
// If X is a vector, reshape it to a matrix.
if (x_dims.size() == 1) {
x_dims.insert(x_dims.begin(), 1);
}
// If Y is a vector, reshape it to a matrix.
if (y_dims.size() == 1) {
y_dims.push_back(1);
}
// Fix the dOut dimensions.
int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
switch (x_dims.size()) {
case 2:
M = transpose_x ? x_dims[1] : x_dims[0];
break;
case 3:
batchCountX = x_dims[0];
M = transpose_x ? x_dims[2] : x_dims[1];
break;
default:
assert(false);
}
switch (y_dims.size()) {
case 2:
N = transpose_y ? y_dims[0] : y_dims[1];
break;
case 3:
batchCountY = y_dims[0];
N = transpose_y ? y_dims[1] : y_dims[2];
break;
default:
assert(false);
}
if (batchCountX && batchCountY) {
PADDLE_ENFORCE_EQ(
batchCountX, batchCountY,
"When Input(X) and Input(Y) are both three dimensional, they "
"must have the same batch dimension.");
}
int batchCount = std::max(batchCountX, batchCountY);
std::vector<int64_t> dout_dims = {M, N};
if (batchCount) {
dout_dims.insert(dout_dims.begin(), batchCount);
}
Tensor X = Reshape<T>(x, make_ddim(x_dims));
Tensor Y = Reshape<T>(y, make_ddim(y_dims));
Tensor dOut = Reshape<T>(dout, make_ddim(dout_dims));
if (dx) {
dx->mutable_data<T>(context.GetPlace());
const Tensor& dOut_for_dX =
(x_dims.size() == 2 && y_dims.size() == 3)
? CombineBatchAndN<Place, T>(context, dOut)
: dOut;
if (x_dims.size() == 2 && y_dims.size() == 3) {
Y = transpose_y ? CombineBatchAndM<T>(Y)
: CombineBatchAndN<Place, T>(context, Y);
}
if (transpose_x) {
math::MatMulFunctor<Place, T>()(context.device_context(), Y,
transpose_y, dOut_for_dX, transpose_x,
T(1), dx, T(0));
} else {
math::MatMulFunctor<Place, T>()(context.device_context(), dOut_for_dX,
transpose_x, Y, !transpose_y, T(1), dx,
T(0));
}
}
if (dy) {
dy->mutable_data<T>(context.GetPlace());
const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3)
? CombineBatchAndM<T>(dOut)
: dOut;
if (y_dims.size() == 2 && x_dims.size() == 3) {
X = transpose_x ? CombineBatchAndN<Place, T>(context, X)
: CombineBatchAndM<T>(X);
dOut = CombineBatchAndM<T>(dOut);
}
if (transpose_y) {
math::MatMulFunctor<Place, T>()(context.device_context(), dOut_for_dY,
transpose_y, X, transpose_x, T(1), dy,
T(0));
} else {
math::MatMulFunctor<Place, T>()(context.device_context(), X,
!transpose_x, dOut_for_dY, transpose_y,
T(1), dy, T(0));
}
}
}
};
} // namespace matmul_detail
using matmul_detail::MatMulKernel;
using matmul_detail::MatMulGradKernel;
} // namespace operators
} // namespace paddle
...@@ -77,6 +77,10 @@ extern void *cublas_dso_handle; ...@@ -77,6 +77,10 @@ extern void *cublas_dso_handle;
__macro(cublasDgemmBatched); \ __macro(cublasDgemmBatched); \
__macro(cublasCgemmBatched); \ __macro(cublasCgemmBatched); \
__macro(cublasZgemmBatched); \ __macro(cublasZgemmBatched); \
__macro(cublasSgemmStridedBatched); \
__macro(cublasDgemmStridedBatched); \
__macro(cublasCgemmStridedBatched); \
__macro(cublasZgemmStridedBatched); \
__macro(cublasSgetrfBatched); \ __macro(cublasSgetrfBatched); \
__macro(cublasSgetriBatched); \ __macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \ __macro(cublasDgetrfBatched); \
......
import unittest
import numpy as np
from op_test import OpTest
def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y):
BATCH_SIZE = 2
M = 3
N = 4
K = 5
if (dim_X == 1 and transpose_X) or (dim_Y == 1 and transpose_Y):
K = 1
if dim_X == 1:
if transpose_X:
shape_X = [M]
else:
shape_X = [K]
if dim_Y == 1:
if transpose_Y:
shape_Y = [N]
else:
shape_Y = [K]
if dim_X >= 2:
if transpose_X:
shape_X = [K, M]
else:
shape_X = [M, K]
if dim_X == 3:
shape_X = [BATCH_SIZE] + shape_X
if dim_Y >= 2:
if transpose_Y:
shape_Y = [N, K]
else:
shape_Y = [K, N]
if dim_Y == 3:
shape_Y = [BATCH_SIZE] + shape_Y
return shape_X, shape_Y
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
"""Reference forward implementation using np.matmul."""
# np.matmul does not support the transpose flags, so we manually
# transpose X and Y appropriately.
if transpose_X:
if X.ndim == 1:
X = X.reshape((X.size, 1))
elif X.ndim == 2:
X = X.T
elif X.ndim == 3:
X = np.transpose(X, (0, 2, 1))
else:
raise ValueError('X must have between 1 and 3 dimensions')
if transpose_Y:
if Y.ndim == 1:
Y = Y.reshape((1, Y.size))
elif Y.ndim == 2:
Y = Y.T
elif Y.ndim == 3:
Y = np.transpose(Y, (0, 2, 1))
else:
raise ValueError('Y must have between 1 and 3 dimensions')
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float32")
return Out
class Generator(object):
def setUp(self):
self.op_type = "matmul"
X = np.random.random(self.shape_X).astype("float32")
Y = np.random.random(self.shape_Y).astype("float32")
Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y)
self.inputs = {'X': X, 'Y': Y}
self.attrs = {
'transpose_X': self.transpose_X,
'transpose_Y': self.transpose_Y
}
self.outputs = {'Out': Out}
def test_check_output(self):
self.check_output(atol=1e-2)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5)
def test_check_grad_ignore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
def test_check_grad_ignore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
# Generate test cases for all possibilities
for dim_X in [1, 2, 3]:
for dim_Y in [1, 2, 3]:
for transpose_X in [False, True]:
for transpose_Y in [False, True]:
test_name = (
'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
dim_X, dim_Y, transpose_X, transpose_Y))
shape_X, shape_Y = generate_compatible_shapes(
dim_X, dim_Y, transpose_X, transpose_Y)
test_class = type(test_name, (Generator, OpTest), {
'shape_X': shape_X,
'shape_Y': shape_Y,
'transpose_X': transpose_X,
'transpose_Y': transpose_Y,
})
globals()[test_name] = test_class
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册