Commit 849442ef authored by F ForFishes

fix the speed & memory of matmul

Parent a85592bc
@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -350,20 +351,158 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
}
};
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
static framework::Tensor FoldInitDims(const framework::Tensor& input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
}
return output;
}
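// Illustrative sketch, not part of the original patch: FoldInitDims only
// reinterprets the leading two dims and shares the underlying memory, e.g. a
// {2, 3, 4} tensor is viewed as a {6, 4} matrix. A hypothetical usage on CPU:
//
//   framework::Tensor t;
//   t.Resize({2, 3, 4});
//   t.mutable_data<float>(platform::CPUPlace());
//   framework::Tensor folded = FoldInitDims(t);  // folded.dims() == {6, 4}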
// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T>
static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
const framework::Tensor& input) {
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
}
framework::Tensor output;
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
std::vector<int> axis = {1, 0, 2};
math::Transpose<DeviceContext, T, 3> trans;
trans(context, input, &output, axis);
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output;
}
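// Illustrative sketch, not part of the original patch: unlike FoldInitDims,
// this fold must physically transpose the data (axis order {1, 0, 2}) before
// flattening, so it allocates new memory. For a {P, M, N} = {2, 3, 4} input:
//
//   auto folded =
//       FoldHeadAndLastDims<platform::CPUDeviceContext, float>(dev_ctx, t);
//   // folded.dims() == {3, 8}; element (m, p * 4 + n) == original (p, m, n)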
/**
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* original x_dim is returned.
*/
static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
if (x_dim.size() > 1) {
return x_dim;
}
return framework::make_ddim({1, x_dim[0]});
}
/**
* Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
* original y_dim is returned.
*/
static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
if (y_dim.size() > 1) {
return y_dim;
}
return framework::make_ddim({y_dim[0], 1});
}
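// Illustrative examples, not part of the original patch:
//   RowMatrixFromVector(framework::make_ddim({8}))     -> [1, 8]
//   ColumnMatrixFromVector(framework::make_ddim({8}))  -> [8, 1]
//   RowMatrixFromVector(framework::make_ddim({4, 8}))  -> [4, 8]  (unchanged)
// Promoting 1-D shapes this way lets vectors take part in the 2-D GEMM path.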
/**
* Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
*
* The shape would be [BatchSize, H, W] or [H, W].
* If transposed, `H,W` will be swapped.
*/
static void ReshapeTensorIntoMatrixSequence(
framework::Tensor* x, const math::MatDescriptor& descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
if (descriptor.batch_size_) {
x->Resize({descriptor.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
framework::Tensor* y,
framework::Tensor* out, bool trans_x,
bool trans_y) {
auto x_dim = RowMatrixFromVector(x->dims());
auto y_dim = ColumnMatrixFromVector(y->dims());
auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_, mat_dim_y.width_});
}
ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
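// Illustrative example, not part of the original patch: with trans_x =
// trans_y = false, x: {B, M, K} and y: {K, N} stay as they are while out is
// resized to {B, M, N}; a 1-D y of shape {K} would first be promoted to the
// column matrix {K, 1}, giving out: {B, M, 1}.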
template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b,
framework::Tensor* out) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
if (a.dims().size() == 3 && b.dims().size() <= 2) {
// transpose_X must be false; if it is true, the transpose costs much time
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast<T>(1), out,
static_cast<T>(0));
}
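// Illustrative example, not part of the original patch: for a: {B, M, K} with
// trans_a = false and b: {K, N}, the batch dim of a is folded into its height
// (height_ = B * M, batch_size_ = 0), so a single (B * M, K) x (K, N) GEMM is
// issued instead of B batched GEMMs; with trans_a = true this fold would need
// a physical transpose, hence the check above.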
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor& b,
bool trans_b, bool is_fold_init_dims_b,
framework::Tensor* out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
trans_a, is_fold_init_dims_b
? FoldInitDims(b)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
trans_b, out);
}
}
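// Illustrative example, not part of the original patch: need_combine is true
// when, say, dout is {B, M, N} but the target gradient is 2-D; the 3-D operand
// is then collapsed either with FoldInitDims (a free reshape of the leading
// dims) or with FoldHeadAndLastDims (a real transpose), as selected by the
// is_fold_init_dims_* flags, before the 2-D MatMul above is invoked.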
void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Input<Tensor>("Y");
auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
// auto* X = ctx.Input<Tensor>("X");
// auto* Y = ctx.Input<Tensor>("Y");
// auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
bool transpose_x = ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.Attr<bool>("trans_y");
auto x = *ctx.Input<framework::Tensor>("X");
auto y = *ctx.Input<framework::Tensor>("Y");
auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
// get dims
std::vector<std::int64_t> x_dims = vectorize(X->dims());
std::vector<std::int64_t> y_dims = vectorize(Y->dims());
std::vector<std::int64_t> dout_dims = vectorize(dOut->dims());
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
@@ -372,76 +511,115 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
// x's or y's dim = 1
// Case1: x's or y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
if (dOut->numel() == 1) {
DotGradFunction<DeviceContext, T>(X, Y, dOut, dx, dy, ctx);
if (dout.numel() == 1) {
DotGradFunction<DeviceContext, T>(&x, &y, &dout, dx, dy, ctx);
return;
}
}
// It is very tricky. For this broadcast, we currently use reduce_sum to
// get the gradient.
if (x_ndim == 1) {
x_dims.insert(x_dims.begin() + 0, 1);
x_ndim += 1;
if (trans_x)
dout_dims.push_back(1);
else
dout_dims.insert(dout_dims.begin() + ndim - 1, 1);
ndim += 1;
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2,
y_dims.cbegin());
}
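// Illustrative example, not part of the original patch: x: {2, 3, M, K} with
// y: {2, 3, K, N} shares its leading batch dims, so is_broadcast is false and
// the fast Case2 path is taken; x: {2, 1, M, K} with y: {2, 3, K, N} does
// broadcast and falls back to the slower Case3 path.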
if (y_ndim == 1) {
y_dims.push_back(1);
y_ndim += 1;
if (trans_y)
dout_dims.insert(dout_dims.begin() + ndim - 1, 1);
else
dout_dims.push_back(1);
ndim += 1;
VLOG(0) << "is_broadcast: " << is_broadcast;
// Case2: no broadcast or no batch size; this path is faster and is the
// same as the matmul in the old version.
if (!is_broadcast) {
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
if (transpose_x && transpose_y) {
CalcInputGrad(ctx, y, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x, true, false, dy);
} else if (transpose_x) {
CalcInputGrad(ctx, y, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x, false, false, dout, false, true, dy);
} else if (transpose_y) {
CalcInputGrad(ctx, dout, false, false, y, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x, false, true, dy);
} else {
CalcInputGrad(ctx, dout, false, false, y, true, false, dx);
CalcInputGrad(ctx, x, true, true, dout, false, true, dy);
}
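// Illustrative note, not part of the original patch: with G = dOut, the four
// branches above compute, in order, dX = Y'G', dY = G'X' (both transposed);
// dX = YG', dY = XG (only X transposed); dX = GY, dY = G'X (only Y
// transposed); and dX = GY', dY = X'G (no transpose). The bool pairs after
// each operand choose FoldInitDims vs. FoldHeadAndLastDims when a 3-D operand
// must be collapsed to 2-D.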
// the normal case
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
} else {
// Case3: broadcast. It needs much time to reduce sum for the broadcast
// and wastes memory. So we should avoid this case in reality.
VLOG(3) << "It needs much time to reduce sum for the broadcast and "
"wastes memory. So we should avoid this case in reality";
Tensor dx_help, dy_help;
if (trans_x) {
if (trans_y) {
if (transpose_x) {
if (transpose_y) {
// X'Y': dA = Y'G', dB = G'X'
if (dx)
MatMulFunction<DeviceContext, T>(Y, dOut, y_dims, dout_dims, &dx_help,
true, true, ctx);
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
&dx_help, true, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(dOut, X, dout_dims, x_dims, &dy_help,
true, true, ctx);
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
&dy_help, true, true, ctx);
} else {
// X'Y: dX = YG', dY = XG
if (dx)
MatMulFunction<DeviceContext, T>(Y, dOut, y_dims, dout_dims, &dx_help,
false, true, ctx);
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(X, dOut, x_dims, dout_dims, &dy_help,
false, false, ctx);
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
&dy_help, false, false, ctx);
}
} else {
if (trans_y) {
if (transpose_y) {
// XY': dX = GY, dY = G'X
if (dx)
MatMulFunction<DeviceContext, T>(dOut, Y, dout_dims, y_dims, &dx_help,
false, false, ctx);
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
&dx_help, false, false, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(dOut, X, dout_dims, x_dims, &dy_help,
true, false, ctx);
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
&dy_help, true, false, ctx);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<DeviceContext, T>(dOut, Y, dout_dims, y_dims, &dx_help,
false, true, ctx);
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(X, dOut, x_dims, dout_dims, &dy_help,
true, false, ctx);
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
&dy_help, true, false, ctx);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
@@ -468,18 +646,20 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
dy_reduce_dims.push_back(idx);
}
}
// reduce sum to get the gradient via ReduceSumForMatmulGrad
if (dx) {
dx->Resize(dx_help.dims());
ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
ctx);
dx->Resize(X->dims());
dx->Resize(x.dims());
}
if (dy) {
dy->Resize(dy_help.dims());
ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
ctx);
dy->Resize(Y->dims());
dy->Resize(y.dims());
}
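// Illustrative example, not part of the original patch: if x is {1, M, K} and
// y is {B, K, N}, dx_help (= G * Y') has shape {B, M, K}; the broadcast dim 0
// is presumably the only entry of dx_reduce_dims, so it is summed away and dx
// is finally resized back to x.dims() = {1, M, K}.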
}
}
};