You need to sign in or sign up before continuing.
提交 c6a6d87f 编写于 作者: Y Yu Yang

Rewrite Matmul, make code cleaner

上级 0285a2b9
...@@ -13,10 +13,47 @@ ...@@ -13,10 +13,47 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include <utility>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
// Do nothing. Blas is a header only library. MatDim GetMatDim(const framework::DDim& dim, int num_flatten_cols, bool trans) {
MatDim retv;
if (num_flatten_cols > 1) {
auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols);
retv.height_ = flatten_dim[0];
retv.width_ = flatten_dim[1];
} else {
if (dim.size() == 1) {
retv.height_ = 1;
retv.width_ = dim[0];
} else if (dim.size() == 2) {
retv.height_ = dim[0];
retv.width_ = dim[1];
} else {
if (dim.size() == 3) {
retv.batch_size_ = dim[0];
retv.height_ = dim[1];
retv.width_ = dim[2];
} else {
auto dim_vec = framework::vectorize(dim);
retv.batch_size_ = 1;
for (size_t i = 0; i < dim_vec.size() - 2; ++i) {
retv.batch_size_ *= dim_vec[i];
retv.height_ = dim_vec[dim_vec.size() - 2];
retv.width_ = dim_vec[dim_vec.size() - 1];
}
}
retv.stride_ = retv.height_ * retv.width_;
}
}
if (trans) {
std::swap(retv.width_, retv.height_);
}
retv.trans_ = trans;
return retv;
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -46,6 +46,17 @@ namespace paddle { ...@@ -46,6 +46,17 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
struct MatDim {
int64_t height_;
int64_t width_;
int64_t stride_{0};
int64_t batch_size_{0};
bool trans_;
};
extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols,
bool trans);
template <typename DeviceContext> template <typename DeviceContext>
class Blas { class Blas {
public: public:
...@@ -90,6 +101,28 @@ class Blas { ...@@ -90,6 +101,28 @@ class Blas {
int K, T alpha, const T* A, const T* B, T beta, T* C, int K, T alpha, const T* A, const T* B, T beta, T* C,
int batchCount, int64_t strideA, int64_t strideB) const; int batchCount, int64_t strideA, int64_t strideB) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a,
const framework::Tensor& mat_b, const MatDim& dim_b, T alpha,
framework::Tensor* mat_out, T beta) const {
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
dim_a.width_, alpha, mat_a.data<T>(),
mat_b.data<T>(), beta, mat_out->data<T>());
} else {
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
this->template BatchedGEMM<T>(
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
dim_a.stride_, dim_b.stride_);
}
}
private: private:
const DeviceContext& context_; const DeviceContext& context_;
}; };
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
namespace math {
// Implements the logic of numpy matmul:
// https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
//
// but allowing also for a, b to be transposed
//
// Both a & b can be 1- to 3-dimensional. Higher rank tensors are not supported
// yet.
template <typename DeviceContext, typename T>
class MatMulFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& a,
bool trans_a, const framework::Tensor& b, bool trans_b,
T alpha, framework::Tensor* out, T beta) {
auto dim_a = a.dims();
auto dim_b = b.dims();
PADDLE_ENFORCE(a.place() == b.place() && b.place() == out->place(),
"Tensors must all be in the same place.");
PADDLE_ENFORCE_GE(dim_a.size(), 1,
"Input tensor a must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_b.size(), 1,
"Input tensor b must be at least 1-dimensional.");
std::vector<int64_t> out_dim;
int64_t batch_count = 1;
if (dim_a.size() > 3) {
PADDLE_ENFORCE(dim_b.size() == dim_a.size(),
"The dimensions of X and Y must be the same, and both of "
"them should be %d-dimensional.",
dim_b.size());
// The first rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
for (int j = 0; j < dim_a.size() - 2; ++j) {
PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j],
"The %d-th dimension of X and Y must be the same.",
j);
out_dim.push_back(dim_a[j]);
batch_count *= dim_a[j];
}
}
int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0,
strideA = 0, strideB = 0;
switch (dim_a.size()) {
case 1:
// similar to np.matmul:
// prepend dimension 1 (no transpose) or append dimension 1 (transpose)
M = trans_a ? dim_a[0] : 1;
kA = trans_a ? 1 : dim_a[0];
break;
case 2:
M = trans_a ? dim_a[1] : dim_a[0];
kA = trans_a ? dim_a[0] : dim_a[1];
break;
case 3:
batchCountA = dim_a[0];
M = trans_a ? dim_a[2] : dim_a[1];
kA = trans_a ? dim_a[1] : dim_a[2];
strideA = M * kA;
break;
default:
batchCountA = batch_count;
size_t mat_s = dim_a.size() - 2;
M = trans_a ? dim_a[mat_s + 1] : dim_a[mat_s];
kA = trans_a ? dim_a[mat_s] : dim_a[mat_s + 1];
strideA = M * kA;
}
switch (dim_b.size()) {
case 1:
// similar to np.matmul:
// append dimension 1 (no transpose) or prepend dimension 1 (transpose)
kB = trans_b ? 1 : dim_b[0];
N = trans_b ? dim_b[0] : 1;
break;
case 2:
kB = trans_b ? dim_b[1] : dim_b[0];
N = trans_b ? dim_b[0] : dim_b[1];
break;
case 3:
batchCountB = dim_b[0];
kB = trans_b ? dim_b[2] : dim_b[1];
N = trans_b ? dim_b[1] : dim_b[2];
strideB = kB * N;
break;
default:
batchCountB = batch_count;
size_t mat_s = dim_b.size() - 2;
kB = trans_b ? dim_b[mat_s + 1] : dim_b[mat_s];
N = trans_b ? dim_b[mat_s] : dim_b[mat_s + 1];
strideB = kB * N;
}
PADDLE_ENFORCE_EQ(
kA, kB,
"First matrix's width must be equal with second matrix's height.");
if (batchCountA && batchCountB) {
PADDLE_ENFORCE_EQ(
batchCountA, batchCountB,
"When input tensors a and b are both batched, they must have the "
"same batch dimension.");
}
int batchCount = std::max(batchCountA, batchCountB);
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
auto blas = GetBlas<DeviceContext, T>(context);
if (!batchCount) {
// regular matrix multiplication
blas.GEMM(transA, transB, M, N, kA, alpha, a.data<T>(), b.data<T>(), beta,
out->data<T>());
} else {
// batched matrix multiplication
blas.BatchedGEMM(transA, transB, M, N, kA, alpha, a.data<T>(),
b.data<T>(), beta, out->data<T>(), batchCount, strideA,
strideB);
}
}
};
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -36,121 +36,39 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -36,121 +36,39 @@ class MatMulOp : public framework::OperatorWithKernel {
auto dim_x = context->GetInputDim("X"); auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y"); auto dim_y = context->GetInputDim("Y");
bool transpose_x = context->Attrs().Get<bool>("transpose_X");
bool transpose_y = context->Attrs().Get<bool>("transpose_Y");
PADDLE_ENFORCE_GE(dim_x.size(), 1,
"Input tensor X must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_y.size(), 1,
"Input tensor Y must be at least 1-dimensional.");
std::vector<int64_t> out_dim;
int64_t batch_count = 1;
if (dim_x.size() > 3) {
PADDLE_ENFORCE_EQ(
dim_y.size(), dim_x.size(),
"The dimensions of X and Y must be the same, and both of "
"them should be %d-dimensional.",
dim_x.size());
// The first rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
for (int j = 0; j < dim_x.size() - 2; ++j) {
PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
"The %d-th dimension of X and Y must be the same.",
j);
out_dim.push_back(dim_x[j]);
batch_count *= dim_x[j];
}
}
int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0; auto mat_dim_x = math::GetMatDim(GetXDim(dim_x), 0,
bool remove_initial_dim = false, remove_final_dim = false; context->Attrs().Get<bool>("transpose_X"));
auto mat_dim_y = math::GetMatDim(GetYDim(dim_y), 0,
switch (dim_x.size()) { context->Attrs().Get<bool>("transpose_Y"));
case 1:
if (transpose_x) {
M = dim_x[0];
KX = 1;
} else {
M = 1;
KX = dim_x[0];
remove_initial_dim = true;
}
break;
case 2:
M = transpose_x ? dim_x[1] : dim_x[0];
KX = transpose_x ? dim_x[0] : dim_x[1];
break;
case 3:
batchCountX = dim_x[0];
M = transpose_x ? dim_x[2] : dim_x[1];
KX = transpose_x ? dim_x[1] : dim_x[2];
break;
default:
batchCountX = batch_count;
size_t mat_s = dim_x.size() - 2;
M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
break;
}
switch (dim_y.size()) { PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
case 1: PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
if (transpose_y) { mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
N = dim_y[0]; std::vector<int64_t> dim_out;
KY = 1; if (mat_dim_x.batch_size_ != 0) {
} else { dim_out = framework::vectorize(dim_x);
N = 1; dim_out[dim_out.size() - 2] = mat_dim_x.height_;
KY = dim_y[0]; dim_out[dim_out.size() - 1] = mat_dim_y.width_;
remove_final_dim = true; } else if (mat_dim_y.batch_size_ != 0) {
} dim_out = framework::vectorize(dim_y);
break; dim_out[dim_out.size() - 2] = mat_dim_x.height_;
case 2: dim_out[dim_out.size() - 1] = mat_dim_y.width_;
KY = transpose_y ? dim_y[1] : dim_y[0]; } else {
N = transpose_y ? dim_y[0] : dim_y[1]; dim_out = {mat_dim_x.height_, mat_dim_y.width_};
break;
case 3:
batchCountY = dim_y[0];
KY = transpose_y ? dim_y[2] : dim_y[1];
N = transpose_y ? dim_y[1] : dim_y[2];
break;
default:
batchCountY = batch_count;
size_t mat_s = dim_y.size() - 2;
KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
} }
PADDLE_ENFORCE_EQ( if (dim_x.size() == 1 && dim_out[dim_out.size() - 2] == 1) {
KX, KY, std::swap(dim_out[dim_out.size() - 2], dim_out[dim_out.size() - 1]);
"First matrix's width must be equal with second matrix's height."); dim_out.resize(dim_out.size() - 1);
if (batchCountX && batchCountY) {
PADDLE_ENFORCE_EQ(
batchCountX, batchCountY,
"When Input(X) and Input(Y) are both three dimensional, they "
"must have the same batch dimension.");
} }
int batchCount = std::max(batchCountX, batchCountY);
std::vector<int64_t> dim_out; if (dim_y.size() == 1 && dim_out[dim_out.size() - 1] == 1) {
if (batchCount) { dim_out.resize(dim_out.size() - 1);
if (dim_x.size() > 3) {
dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
} else {
dim_out.push_back(batchCount);
}
} }
if (!remove_initial_dim) {
dim_out.push_back(M); if (dim_out.empty()) {
} dim_out = {1};
if (!remove_final_dim) {
dim_out.push_back(N);
}
if (dim_out.size() == 0) {
// We don't support 0-dimensional Tensors (scalars), so instead
// treat the output as a Tensor of shape (1, ) in this case.
dim_out.push_back(1);
} }
context->SetOutputDim("Out", framework::make_ddim(dim_out)); context->SetOutputDim("Out", framework::make_ddim(dim_out));
context->ShareLoD("X", /*->*/ "Out"); context->ShareLoD("X", /*->*/ "Out");
......
...@@ -15,55 +15,56 @@ limitations under the License. */ ...@@ -15,55 +15,56 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matmul.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace matmul_detail { inline framework::DDim GetXDim(const framework::DDim& x_dim) {
if (x_dim.size() > 1) {
return x_dim;
}
return framework::make_ddim({1, x_dim[0]});
}
using Tensor = framework::Tensor; inline framework::DDim GetYDim(const framework::DDim& y_dim) {
using DDim = framework::DDim; if (y_dim.size() > 1) {
using framework::make_ddim; return y_dim;
using framework::vectorize; }
return framework::make_ddim({y_dim[0], 1});
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulKernel : public framework::OpKernel<T> { class MatMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X"); auto& x =
const Tensor& y = *context.Input<Tensor>("Y"); detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
Tensor* out = context.Output<Tensor>("Out"); auto& y =
detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y");
math::MatMulFunctor<DeviceContext, T>()( auto blas = math::GetBlas<DeviceContext, T>(context);
context.template device_context<DeviceContext>(), x, transpose_x, y, auto mat_dim_a = math::GetMatDim(GetXDim(x.dims()), 0,
transpose_y, T(1), out, T(0)); context.Attr<bool>("transpose_X"));
auto mat_dim_b = math::GetMatDim(GetYDim(y.dims()), 0,
context.Attr<bool>("transpose_Y"));
blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
} }
}; };
template <typename T>
inline Tensor Reshape(const Tensor& input, const DDim& dims) {
Tensor output;
output.ShareDataWith(input);
output.Resize(dims);
return output;
}
// Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
template <typename T> inline framework::Tensor CombineBatchAndM(const framework::Tensor& input) {
Tensor CombineBatchAndM(const Tensor& input) { auto output = input;
Tensor output;
output.ShareDataWith(input);
auto in_dims = input.dims(); auto in_dims = input.dims();
if (in_dims.size() == 3) { if (in_dims.size() == 3) {
std::vector<int64_t> out_dims = {in_dims[0] * in_dims[1], in_dims[2]}; output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
output.Resize(make_ddim(out_dims));
} }
return output; return output;
} }
...@@ -72,23 +73,57 @@ Tensor CombineBatchAndM(const Tensor& input) { ...@@ -72,23 +73,57 @@ Tensor CombineBatchAndM(const Tensor& input) {
// (Warning: This requires transposing data and writes into new memory.) // (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
Tensor CombineBatchAndN(const DeviceContext& context, const Tensor& input) { inline framework::Tensor CombineBatchAndN(const DeviceContext& context,
Tensor output; const framework::Tensor& input) {
auto in_dims = input.dims(); auto in_dims = input.dims();
if (in_dims.size() == 3) { if (in_dims.size() != 3) {
output.Resize({in_dims[1], in_dims[0], in_dims[2]}); return input;
output.mutable_data<T>(context.GetPlace());
std::vector<int> axis = {1, 0, 2};
math::Transpose<DeviceContext, T, 3> trans;
trans(context, input, &output, axis);
std::vector<int64_t> out_dims = {in_dims[1], in_dims[0] * in_dims[2]};
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
} else {
output.ShareDataWith(input);
} }
framework::Tensor output;
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
std::vector<int> axis = {1, 0, 2};
math::Transpose<DeviceContext, T, 3> trans;
trans(context, input, &output, axis);
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output; return output;
} }
inline void NormalizeTensorShape(framework::Tensor* x,
const math::MatDim& mat_dim_x) {
int64_t h, w;
h = mat_dim_x.height_;
w = mat_dim_x.width_;
if (mat_dim_x.trans_) {
std::swap(w, h);
}
if (mat_dim_x.batch_size_) {
x->Resize({mat_dim_x.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
inline void NormalizeXYOutTensorShape(framework::Tensor* x,
framework::Tensor* y,
framework::Tensor* out, bool trans_a,
bool trans_b) {
auto x_dim = GetXDim(x->dims());
auto y_dim = GetYDim(y->dims());
auto mat_dim_x = math::GetMatDim(x_dim, 0, trans_a);
auto mat_dim_y = math::GetMatDim(y_dim, 0, trans_b);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_, mat_dim_y.width_});
}
NormalizeTensorShape(x, mat_dim_x);
NormalizeTensorShape(y, mat_dim_y);
}
// Using dimensional constraints on matrix multiplication, it is // Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y // straight-forward to check the following table for when X and Y
// are both matrices. // are both matrices.
...@@ -117,128 +152,91 @@ Tensor CombineBatchAndN(const DeviceContext& context, const Tensor& input) { ...@@ -117,128 +152,91 @@ Tensor CombineBatchAndN(const DeviceContext& context, const Tensor& input) {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulGradKernel : public framework::OpKernel<T> { class MatMulGradKernel : public framework::OpKernel<T> {
public: public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b,
framework::Tensor* out) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::GetMatDim(a.dims(), 0, trans_a);
auto mat_dim_b = math::GetMatDim(b.dims(), 0, trans_b);
blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_combine_m_a, const framework::Tensor& b,
bool trans_b, bool is_combine_m_b,
framework::Tensor* out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(
context, is_combine_m_a ? CombineBatchAndM(a)
: CombineBatchAndN<DeviceContext, T>(ctx, a),
trans_a, is_combine_m_b ? CombineBatchAndM(b)
: CombineBatchAndN<DeviceContext, T>(ctx, b),
trans_b, out);
}
}
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X"); auto x = *context.Input<framework::Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y"); auto y = *context.Input<framework::Tensor>("Y");
const Tensor& dout = *context.Input<Tensor>(framework::GradVarName("Out")); auto dout =
Tensor* dx = context.Output<Tensor>(framework::GradVarName("X")); *context.Input<framework::Tensor>(framework::GradVarName("Out"));
Tensor* dy = context.Output<Tensor>(framework::GradVarName("Y")); auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
bool transpose_x = context.Attr<bool>("transpose_X"); bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y"); bool transpose_y = context.Attr<bool>("transpose_Y");
std::vector<int64_t> x_dims = vectorize(x.dims()); NormalizeXYOutTensorShape(&x, &y, &dout, transpose_x, transpose_y);
std::vector<int64_t> y_dims = vectorize(y.dims()); framework::DDim dx_dims;
if (dx) {
// If X is a vector, reshape it to a matrix. dx_dims = dx->dims();
if (x_dims.size() == 1) { if (dx_dims != x.dims()) {
x_dims.insert(x_dims.begin(), 1); dx->Resize(x.dims());
} }
// If Y is a vector, reshape it to a matrix.
if (y_dims.size() == 1) {
y_dims.push_back(1);
} }
int batch_count = 0; framework::DDim dy_dims;
// The first rank-2 dimensions are accumulated on the batch_count, and the if (dy) {
// last two dimensions are used for matrix multiplication. dy_dims = dy->dims();
if (x_dims.size() > 3) { if (dy_dims != y.dims()) {
batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1, dy->Resize(y.dims());
std::multiplies<int>()); }
}
// Fix the dOut dimensions.
int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
switch (x_dims.size()) {
case 2:
M = transpose_x ? x_dims[1] : x_dims[0];
break;
case 3:
batchCountX = x_dims[0];
M = transpose_x ? x_dims[2] : x_dims[1];
break;
default:
batchCountX = batch_count;
size_t mat_s = x_dims.size() - 2;
M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
} }
switch (y_dims.size()) { if (transpose_x && transpose_y) {
case 2: CalcInputGrad(context, y, true, true, dout, true, false, dx);
N = transpose_y ? y_dims[0] : y_dims[1]; CalcInputGrad(context, dout, true, true, x, true, false, dy);
break; } else if (transpose_x && !transpose_y) {
case 3: CalcInputGrad(context, y, false, false, dout, true, false, dx);
batchCountY = y_dims[0]; CalcInputGrad(context, x, false, false, dout, false, true, dy);
N = transpose_y ? y_dims[1] : y_dims[2]; } else if (!transpose_x && transpose_y) {
break; CalcInputGrad(context, dout, false, false, y, false, true, dx);
default: CalcInputGrad(context, dout, true, true, x, false, true, dy);
batchCountY = batch_count; } else {
size_t mat_s = y_dims.size() - 2; CalcInputGrad(context, dout, false, false, y, true, false, dx);
N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1]; CalcInputGrad(context, x, true, true, dout, false, true, dy);
}
if (batchCountX && batchCountY) {
PADDLE_ENFORCE_EQ(
batchCountX, batchCountY,
"When Input(X) and Input(Y) are both three dimensional, they "
"must have the same batch dimension.");
}
int batchCount = std::max(batchCountX, batchCountY);
std::vector<int64_t> dout_dims = {M, N};
if (batchCount) {
if (x_dims.size() > 3) {
dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2);
} else {
dout_dims.insert(dout_dims.begin(), batchCount);
}
} }
Tensor X = Reshape<T>(x, make_ddim(x_dims));
Tensor Y = Reshape<T>(y, make_ddim(y_dims));
Tensor dOut = Reshape<T>(dout, make_ddim(dout_dims));
auto& dev_ctx = context.template device_context<DeviceContext>();
if (dx) { if (dx) {
dx->mutable_data<T>(context.GetPlace()); if (dx_dims != x.dims()) {
const Tensor& dOut_for_dX = dx->Resize(dx_dims);
(x_dims.size() == 2 && y_dims.size() == 3)
? CombineBatchAndN<DeviceContext, T>(dev_ctx, dOut)
: dOut;
if (x_dims.size() == 2 && y_dims.size() == 3) {
Y = transpose_y ? CombineBatchAndM<T>(Y)
: CombineBatchAndN<DeviceContext, T>(dev_ctx, Y);
}
if (transpose_x) {
math::MatMulFunctor<DeviceContext, T>()(
dev_ctx, Y, transpose_y, dOut_for_dX, transpose_x, T(1), dx, T(0));
} else {
math::MatMulFunctor<DeviceContext, T>()(
dev_ctx, dOut_for_dX, transpose_x, Y, !transpose_y, T(1), dx, T(0));
} }
} }
if (dy) { if (dy) {
dy->mutable_data<T>(context.GetPlace()); if (dy_dims != y.dims()) {
const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3) dy->Resize(dy_dims);
? CombineBatchAndM<T>(dOut)
: dOut;
if (y_dims.size() == 2 && x_dims.size() == 3) {
X = transpose_x ? CombineBatchAndN<DeviceContext, T>(dev_ctx, X)
: CombineBatchAndM<T>(X);
dOut = CombineBatchAndM<T>(dOut);
}
if (transpose_y) {
math::MatMulFunctor<DeviceContext, T>()(
dev_ctx, dOut_for_dY, transpose_y, X, transpose_x, T(1), dy, T(0));
} else {
math::MatMulFunctor<DeviceContext, T>()(
dev_ctx, X, !transpose_x, dOut_for_dY, transpose_y, T(1), dy, T(0));
} }
} }
} }
}; };
} // namespace matmul_detail
using matmul_detail::MatMulKernel;
using matmul_detail::MatMulGradKernel;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -111,21 +111,24 @@ class Generator(object): ...@@ -111,21 +111,24 @@ class Generator(object):
# Generate test cases for all possibilities # Generate test cases for all possibilities
for dim_X in [1, 2, 3]: def inject_test(dim_x, dim_y, trans_x, trans_y):
for dim_Y in [1, 2, 3]: test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
for transpose_X in [False, True]: dim_x, dim_y, trans_x, trans_y))
for transpose_Y in [False, True]: shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x,
test_name = ( trans_y)
'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format( globals()[test_name] = type(test_name, (Generator, OpTest), {
dim_X, dim_Y, transpose_X, transpose_Y)) 'shape_X': shape_x,
shape_X, shape_Y = generate_compatible_shapes( 'shape_Y': shape_y,
dim_X, dim_Y, transpose_X, transpose_Y) 'transpose_X': trans_x,
globals()[test_name] = type(test_name, (Generator, OpTest), { 'transpose_Y': trans_y,
'shape_X': shape_X, })
'shape_Y': shape_Y,
'transpose_X': transpose_X,
'transpose_Y': transpose_Y, for dim_X in (1, 2, 3):
}) for dim_Y in (1, 2, 3):
for transose_x in (False, True):
for transose_y in (False, True):
inject_test(dim_X, dim_Y, transose_x, transose_y)
# Test case n-dim # Test case n-dim
...@@ -149,7 +152,7 @@ def generate_compatible_shapes(dim, transpose_X, transpose_Y): ...@@ -149,7 +152,7 @@ def generate_compatible_shapes(dim, transpose_X, transpose_Y):
return shape_X, shape_Y return shape_X, shape_Y
# Test case n-dim # # Test case n-dim
for dim in [4]: for dim in [4]:
for transpose_X in [False, True]: for transpose_X in [False, True]:
for transpose_Y in [False, True]: for transpose_Y in [False, True]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册