提交 c6a6d87f 编写于 作者: Y Yu Yang

Rewrite Matmul, make code cleaner

上级 0285a2b9
......@@ -13,10 +13,47 @@
// limitations under the License.
#include "paddle/fluid/operators/math/blas.h"
#include <utility>
namespace paddle {
namespace operators {
namespace math {
// Do nothing. Blas is a header only library.
MatDim GetMatDim(const framework::DDim& dim, int num_flatten_cols, bool trans) {
MatDim retv;
if (num_flatten_cols > 1) {
auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols);
retv.height_ = flatten_dim[0];
retv.width_ = flatten_dim[1];
} else {
if (dim.size() == 1) {
retv.height_ = 1;
retv.width_ = dim[0];
} else if (dim.size() == 2) {
retv.height_ = dim[0];
retv.width_ = dim[1];
} else {
if (dim.size() == 3) {
retv.batch_size_ = dim[0];
retv.height_ = dim[1];
retv.width_ = dim[2];
} else {
auto dim_vec = framework::vectorize(dim);
retv.batch_size_ = 1;
for (size_t i = 0; i < dim_vec.size() - 2; ++i) {
retv.batch_size_ *= dim_vec[i];
retv.height_ = dim_vec[dim_vec.size() - 2];
retv.width_ = dim_vec[dim_vec.size() - 1];
}
}
retv.stride_ = retv.height_ * retv.width_;
}
}
if (trans) {
std::swap(retv.width_, retv.height_);
}
retv.trans_ = trans;
return retv;
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -46,6 +46,17 @@ namespace paddle {
namespace operators {
namespace math {
struct MatDim {
int64_t height_;
int64_t width_;
int64_t stride_{0};
int64_t batch_size_{0};
bool trans_;
};
extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols,
bool trans);
template <typename DeviceContext>
class Blas {
public:
......@@ -90,6 +101,28 @@ class Blas {
int K, T alpha, const T* A, const T* B, T beta, T* C,
int batchCount, int64_t strideA, int64_t strideB) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a,
const framework::Tensor& mat_b, const MatDim& dim_b, T alpha,
framework::Tensor* mat_out, T beta) const {
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
dim_a.width_, alpha, mat_a.data<T>(),
mat_b.data<T>(), beta, mat_out->data<T>());
} else {
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
this->template BatchedGEMM<T>(
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
dim_a.stride_, dim_b.stride_);
}
}
private:
const DeviceContext& context_;
};
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
namespace math {
// Implements the logic of numpy matmul:
// https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
//
// but allowing also for a, b to be transposed
//
// Both a & b can be 1- to 3-dimensional. Higher rank tensors are not supported
// yet.
template <typename DeviceContext, typename T>
class MatMulFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& a,
bool trans_a, const framework::Tensor& b, bool trans_b,
T alpha, framework::Tensor* out, T beta) {
auto dim_a = a.dims();
auto dim_b = b.dims();
PADDLE_ENFORCE(a.place() == b.place() && b.place() == out->place(),
"Tensors must all be in the same place.");
PADDLE_ENFORCE_GE(dim_a.size(), 1,
"Input tensor a must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_b.size(), 1,
"Input tensor b must be at least 1-dimensional.");
std::vector<int64_t> out_dim;
int64_t batch_count = 1;
if (dim_a.size() > 3) {
PADDLE_ENFORCE(dim_b.size() == dim_a.size(),
"The dimensions of X and Y must be the same, and both of "
"them should be %d-dimensional.",
dim_b.size());
// The first rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
for (int j = 0; j < dim_a.size() - 2; ++j) {
PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j],
"The %d-th dimension of X and Y must be the same.",
j);
out_dim.push_back(dim_a[j]);
batch_count *= dim_a[j];
}
}
int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0,
strideA = 0, strideB = 0;
switch (dim_a.size()) {
case 1:
// similar to np.matmul:
// prepend dimension 1 (no transpose) or append dimension 1 (transpose)
M = trans_a ? dim_a[0] : 1;
kA = trans_a ? 1 : dim_a[0];
break;
case 2:
M = trans_a ? dim_a[1] : dim_a[0];
kA = trans_a ? dim_a[0] : dim_a[1];
break;
case 3:
batchCountA = dim_a[0];
M = trans_a ? dim_a[2] : dim_a[1];
kA = trans_a ? dim_a[1] : dim_a[2];
strideA = M * kA;
break;
default:
batchCountA = batch_count;
size_t mat_s = dim_a.size() - 2;
M = trans_a ? dim_a[mat_s + 1] : dim_a[mat_s];
kA = trans_a ? dim_a[mat_s] : dim_a[mat_s + 1];
strideA = M * kA;
}
switch (dim_b.size()) {
case 1:
// similar to np.matmul:
// append dimension 1 (no transpose) or prepend dimension 1 (transpose)
kB = trans_b ? 1 : dim_b[0];
N = trans_b ? dim_b[0] : 1;
break;
case 2:
kB = trans_b ? dim_b[1] : dim_b[0];
N = trans_b ? dim_b[0] : dim_b[1];
break;
case 3:
batchCountB = dim_b[0];
kB = trans_b ? dim_b[2] : dim_b[1];
N = trans_b ? dim_b[1] : dim_b[2];
strideB = kB * N;
break;
default:
batchCountB = batch_count;
size_t mat_s = dim_b.size() - 2;
kB = trans_b ? dim_b[mat_s + 1] : dim_b[mat_s];
N = trans_b ? dim_b[mat_s] : dim_b[mat_s + 1];
strideB = kB * N;
}
PADDLE_ENFORCE_EQ(
kA, kB,
"First matrix's width must be equal with second matrix's height.");
if (batchCountA && batchCountB) {
PADDLE_ENFORCE_EQ(
batchCountA, batchCountB,
"When input tensors a and b are both batched, they must have the "
"same batch dimension.");
}
int batchCount = std::max(batchCountA, batchCountB);
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
auto blas = GetBlas<DeviceContext, T>(context);
if (!batchCount) {
// regular matrix multiplication
blas.GEMM(transA, transB, M, N, kA, alpha, a.data<T>(), b.data<T>(), beta,
out->data<T>());
} else {
// batched matrix multiplication
blas.BatchedGEMM(transA, transB, M, N, kA, alpha, a.data<T>(),
b.data<T>(), beta, out->data<T>(), batchCount, strideA,
strideB);
}
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -36,121 +36,39 @@ class MatMulOp : public framework::OperatorWithKernel {
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
bool transpose_x = context->Attrs().Get<bool>("transpose_X");
bool transpose_y = context->Attrs().Get<bool>("transpose_Y");
PADDLE_ENFORCE_GE(dim_x.size(), 1,
"Input tensor X must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_y.size(), 1,
"Input tensor Y must be at least 1-dimensional.");
std::vector<int64_t> out_dim;
int64_t batch_count = 1;
if (dim_x.size() > 3) {
PADDLE_ENFORCE_EQ(
dim_y.size(), dim_x.size(),
"The dimensions of X and Y must be the same, and both of "
"them should be %d-dimensional.",
dim_x.size());
// The first rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
for (int j = 0; j < dim_x.size() - 2; ++j) {
PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
"The %d-th dimension of X and Y must be the same.",
j);
out_dim.push_back(dim_x[j]);
batch_count *= dim_x[j];
}
}
int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
bool remove_initial_dim = false, remove_final_dim = false;
switch (dim_x.size()) {
case 1:
if (transpose_x) {
M = dim_x[0];
KX = 1;
} else {
M = 1;
KX = dim_x[0];
remove_initial_dim = true;
}
break;
case 2:
M = transpose_x ? dim_x[1] : dim_x[0];
KX = transpose_x ? dim_x[0] : dim_x[1];
break;
case 3:
batchCountX = dim_x[0];
M = transpose_x ? dim_x[2] : dim_x[1];
KX = transpose_x ? dim_x[1] : dim_x[2];
break;
default:
batchCountX = batch_count;
size_t mat_s = dim_x.size() - 2;
M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
break;
}
auto mat_dim_x = math::GetMatDim(GetXDim(dim_x), 0,
context->Attrs().Get<bool>("transpose_X"));
auto mat_dim_y = math::GetMatDim(GetYDim(dim_y), 0,
context->Attrs().Get<bool>("transpose_Y"));
switch (dim_y.size()) {
case 1:
if (transpose_y) {
N = dim_y[0];
KY = 1;
} else {
N = 1;
KY = dim_y[0];
remove_final_dim = true;
}
break;
case 2:
KY = transpose_y ? dim_y[1] : dim_y[0];
N = transpose_y ? dim_y[0] : dim_y[1];
break;
case 3:
batchCountY = dim_y[0];
KY = transpose_y ? dim_y[2] : dim_y[1];
N = transpose_y ? dim_y[1] : dim_y[2];
break;
default:
batchCountY = batch_count;
size_t mat_s = dim_y.size() - 2;
KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
std::vector<int64_t> dim_out;
if (mat_dim_x.batch_size_ != 0) {
dim_out = framework::vectorize(dim_x);
dim_out[dim_out.size() - 2] = mat_dim_x.height_;
dim_out[dim_out.size() - 1] = mat_dim_y.width_;
} else if (mat_dim_y.batch_size_ != 0) {
dim_out = framework::vectorize(dim_y);
dim_out[dim_out.size() - 2] = mat_dim_x.height_;
dim_out[dim_out.size() - 1] = mat_dim_y.width_;
} else {
dim_out = {mat_dim_x.height_, mat_dim_y.width_};
}
PADDLE_ENFORCE_EQ(
KX, KY,
"First matrix's width must be equal with second matrix's height.");
if (batchCountX && batchCountY) {
PADDLE_ENFORCE_EQ(
batchCountX, batchCountY,
"When Input(X) and Input(Y) are both three dimensional, they "
"must have the same batch dimension.");
if (dim_x.size() == 1 && dim_out[dim_out.size() - 2] == 1) {
std::swap(dim_out[dim_out.size() - 2], dim_out[dim_out.size() - 1]);
dim_out.resize(dim_out.size() - 1);
}
int batchCount = std::max(batchCountX, batchCountY);
std::vector<int64_t> dim_out;
if (batchCount) {
if (dim_x.size() > 3) {
dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
} else {
dim_out.push_back(batchCount);
}
if (dim_y.size() == 1 && dim_out[dim_out.size() - 1] == 1) {
dim_out.resize(dim_out.size() - 1);
}
if (!remove_initial_dim) {
dim_out.push_back(M);
}
if (!remove_final_dim) {
dim_out.push_back(N);
}
if (dim_out.size() == 0) {
// We don't support 0-dimensional Tensors (scalars), so instead
// treat the output as a Tensor of shape (1, ) in this case.
dim_out.push_back(1);
if (dim_out.empty()) {
dim_out = {1};
}
context->SetOutputDim("Out", framework::make_ddim(dim_out));
context->ShareLoD("X", /*->*/ "Out");
......
......@@ -15,55 +15,56 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matmul.h"
namespace paddle {
namespace operators {
namespace matmul_detail {
inline framework::DDim GetXDim(const framework::DDim& x_dim) {
if (x_dim.size() > 1) {
return x_dim;
}
return framework::make_ddim({1, x_dim[0]});
}
using Tensor = framework::Tensor;
using DDim = framework::DDim;
using framework::make_ddim;
using framework::vectorize;
inline framework::DDim GetYDim(const framework::DDim& y_dim) {
if (y_dim.size() > 1) {
return y_dim;
}
return framework::make_ddim({y_dim[0], 1});
}
template <typename DeviceContext, typename T>
class MatMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
Tensor* out = context.Output<Tensor>("Out");
auto& x =
detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
auto& y =
detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y");
math::MatMulFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), x, transpose_x, y,
transpose_y, T(1), out, T(0));
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::GetMatDim(GetXDim(x.dims()), 0,
context.Attr<bool>("transpose_X"));
auto mat_dim_b = math::GetMatDim(GetYDim(y.dims()), 0,
context.Attr<bool>("transpose_Y"));
blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
}
};
template <typename T>
inline Tensor Reshape(const Tensor& input, const DDim& dims) {
Tensor output;
output.ShareDataWith(input);
output.Resize(dims);
return output;
}
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
template <typename T>
Tensor CombineBatchAndM(const Tensor& input) {
Tensor output;
output.ShareDataWith(input);
inline framework::Tensor CombineBatchAndM(const framework::Tensor& input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
std::vector<int64_t> out_dims = {in_dims[0] * in_dims[1], in_dims[2]};
output.Resize(make_ddim(out_dims));
output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
}
return output;
}
......@@ -72,23 +73,57 @@ Tensor CombineBatchAndM(const Tensor& input) {
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T>
Tensor CombineBatchAndN(const DeviceContext& context, const Tensor& input) {
Tensor output;
inline framework::Tensor CombineBatchAndN(const DeviceContext& context,
const framework::Tensor& input) {
auto in_dims = input.dims();
if (in_dims.size() == 3) {
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
std::vector<int> axis = {1, 0, 2};
math::Transpose<DeviceContext, T, 3> trans;
trans(context, input, &output, axis);
std::vector<int64_t> out_dims = {in_dims[1], in_dims[0] * in_dims[2]};
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
} else {
output.ShareDataWith(input);
if (in_dims.size() != 3) {
return input;
}
framework::Tensor output;
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
std::vector<int> axis = {1, 0, 2};
math::Transpose<DeviceContext, T, 3> trans;
trans(context, input, &output, axis);
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output;
}
inline void NormalizeTensorShape(framework::Tensor* x,
const math::MatDim& mat_dim_x) {
int64_t h, w;
h = mat_dim_x.height_;
w = mat_dim_x.width_;
if (mat_dim_x.trans_) {
std::swap(w, h);
}
if (mat_dim_x.batch_size_) {
x->Resize({mat_dim_x.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
inline void NormalizeXYOutTensorShape(framework::Tensor* x,
framework::Tensor* y,
framework::Tensor* out, bool trans_a,
bool trans_b) {
auto x_dim = GetXDim(x->dims());
auto y_dim = GetYDim(y->dims());
auto mat_dim_x = math::GetMatDim(x_dim, 0, trans_a);
auto mat_dim_y = math::GetMatDim(y_dim, 0, trans_b);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_, mat_dim_y.width_});
}
NormalizeTensorShape(x, mat_dim_x);
NormalizeTensorShape(y, mat_dim_y);
}
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
......@@ -117,128 +152,91 @@ Tensor CombineBatchAndN(const DeviceContext& context, const Tensor& input) {
template <typename DeviceContext, typename T>
class MatMulGradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b,
framework::Tensor* out) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::GetMatDim(a.dims(), 0, trans_a);
auto mat_dim_b = math::GetMatDim(b.dims(), 0, trans_b);
blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_combine_m_a, const framework::Tensor& b,
bool trans_b, bool is_combine_m_b,
framework::Tensor* out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(
context, is_combine_m_a ? CombineBatchAndM(a)
: CombineBatchAndN<DeviceContext, T>(ctx, a),
trans_a, is_combine_m_b ? CombineBatchAndM(b)
: CombineBatchAndN<DeviceContext, T>(ctx, b),
trans_b, out);
}
}
void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
const Tensor& dout = *context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* dx = context.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = context.Output<Tensor>(framework::GradVarName("Y"));
auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y");
auto dout =
*context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y");
std::vector<int64_t> x_dims = vectorize(x.dims());
std::vector<int64_t> y_dims = vectorize(y.dims());
// If X is a vector, reshape it to a matrix.
if (x_dims.size() == 1) {
x_dims.insert(x_dims.begin(), 1);
}
// If Y is a vector, reshape it to a matrix.
if (y_dims.size() == 1) {
y_dims.push_back(1);
NormalizeXYOutTensorShape(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
int batch_count = 0;
// The first rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
if (x_dims.size() > 3) {
batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
std::multiplies<int>());
}
// Fix the dOut dimensions.
int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
switch (x_dims.size()) {
case 2:
M = transpose_x ? x_dims[1] : x_dims[0];
break;
case 3:
batchCountX = x_dims[0];
M = transpose_x ? x_dims[2] : x_dims[1];
break;
default:
batchCountX = batch_count;
size_t mat_s = x_dims.size() - 2;
M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
switch (y_dims.size()) {
case 2:
N = transpose_y ? y_dims[0] : y_dims[1];
break;
case 3:
batchCountY = y_dims[0];
N = transpose_y ? y_dims[1] : y_dims[2];
break;
default:
batchCountY = batch_count;
size_t mat_s = y_dims.size() - 2;
N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1];
}
if (batchCountX && batchCountY) {
PADDLE_ENFORCE_EQ(
batchCountX, batchCountY,
"When Input(X) and Input(Y) are both three dimensional, they "
"must have the same batch dimension.");
}
int batchCount = std::max(batchCountX, batchCountY);
std::vector<int64_t> dout_dims = {M, N};
if (batchCount) {
if (x_dims.size() > 3) {
dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2);
} else {
dout_dims.insert(dout_dims.begin(), batchCount);
}
if (transpose_x && transpose_y) {
CalcInputGrad(context, y, true, true, dout, true, false, dx);
CalcInputGrad(context, dout, true, true, x, true, false, dy);
} else if (transpose_x && !transpose_y) {
CalcInputGrad(context, y, false, false, dout, true, false, dx);
CalcInputGrad(context, x, false, false, dout, false, true, dy);
} else if (!transpose_x && transpose_y) {
CalcInputGrad(context, dout, false, false, y, false, true, dx);
CalcInputGrad(context, dout, true, true, x, false, true, dy);
} else {
CalcInputGrad(context, dout, false, false, y, true, false, dx);
CalcInputGrad(context, x, true, true, dout, false, true, dy);
}
Tensor X = Reshape<T>(x, make_ddim(x_dims));
Tensor Y = Reshape<T>(y, make_ddim(y_dims));
Tensor dOut = Reshape<T>(dout, make_ddim(dout_dims));
auto& dev_ctx = context.template device_context<DeviceContext>();
if (dx) {
dx->mutable_data<T>(context.GetPlace());
const Tensor& dOut_for_dX =
(x_dims.size() == 2 && y_dims.size() == 3)
? CombineBatchAndN<DeviceContext, T>(dev_ctx, dOut)
: dOut;
if (x_dims.size() == 2 && y_dims.size() == 3) {
Y = transpose_y ? CombineBatchAndM<T>(Y)
: CombineBatchAndN<DeviceContext, T>(dev_ctx, Y);
}
if (transpose_x) {
math::MatMulFunctor<DeviceContext, T>()(
dev_ctx, Y, transpose_y, dOut_for_dX, transpose_x, T(1), dx, T(0));
} else {
math::MatMulFunctor<DeviceContext, T>()(
dev_ctx, dOut_for_dX, transpose_x, Y, !transpose_y, T(1), dx, T(0));
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
dy->mutable_data<T>(context.GetPlace());
const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3)
? CombineBatchAndM<T>(dOut)
: dOut;
if (y_dims.size() == 2 && x_dims.size() == 3) {
X = transpose_x ? CombineBatchAndN<DeviceContext, T>(dev_ctx, X)
: CombineBatchAndM<T>(X);
dOut = CombineBatchAndM<T>(dOut);
}
if (transpose_y) {
math::MatMulFunctor<DeviceContext, T>()(
dev_ctx, dOut_for_dY, transpose_y, X, transpose_x, T(1), dy, T(0));
} else {
math::MatMulFunctor<DeviceContext, T>()(
dev_ctx, X, !transpose_x, dOut_for_dY, transpose_y, T(1), dy, T(0));
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
}
};
} // namespace matmul_detail
using matmul_detail::MatMulKernel;
using matmul_detail::MatMulGradKernel;
} // namespace operators
} // namespace paddle
......@@ -111,21 +111,24 @@ class Generator(object):
# Generate test cases for all possibilities
for dim_X in [1, 2, 3]:
for dim_Y in [1, 2, 3]:
for transpose_X in [False, True]:
for transpose_Y in [False, True]:
test_name = (
'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
dim_X, dim_Y, transpose_X, transpose_Y))
shape_X, shape_Y = generate_compatible_shapes(
dim_X, dim_Y, transpose_X, transpose_Y)
globals()[test_name] = type(test_name, (Generator, OpTest), {
'shape_X': shape_X,
'shape_Y': shape_Y,
'transpose_X': transpose_X,
'transpose_Y': transpose_Y,
})
def inject_test(dim_x, dim_y, trans_x, trans_y):
test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
dim_x, dim_y, trans_x, trans_y))
shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x,
trans_y)
globals()[test_name] = type(test_name, (Generator, OpTest), {
'shape_X': shape_x,
'shape_Y': shape_y,
'transpose_X': trans_x,
'transpose_Y': trans_y,
})
for dim_X in (1, 2, 3):
for dim_Y in (1, 2, 3):
for transose_x in (False, True):
for transose_y in (False, True):
inject_test(dim_X, dim_Y, transose_x, transose_y)
# Test case n-dim
......@@ -149,7 +152,7 @@ def generate_compatible_shapes(dim, transpose_X, transpose_Y):
return shape_X, shape_Y
# Test case n-dim
# # Test case n-dim
for dim in [4]:
for transpose_X in [False, True]:
for transpose_Y in [False, True]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册