diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index ee485bd1711e21b86cdf65fdb2f5f0793e42beb4..3099b06aecb5830a8685a8742336a769cc308691 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" @@ -350,20 +351,158 @@ class MatMulV2Kernel : public framework::OpKernel { } }; +// Reshape a rank-3 tensor from P x M x N to (P * M) x N. +// Identity op if the tensor is not of rank 3. +static framework::Tensor FoldInitDims(const framework::Tensor& input) { + auto output = input; + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); + } + return output; +} + +// Reshape a rank-3 tensor from P x M x N to M x (P * N). +// (Warning: This requires transposing data and writes into new memory.) +// Identity op if the tensor is not of rank 3. +template +static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, + const framework::Tensor& input) { + auto in_dims = input.dims(); + if (in_dims.size() != 3) { + return input; + } + framework::Tensor output; + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); + output.mutable_data(context.GetPlace()); + std::vector axis = {1, 0, 2}; + math::Transpose trans; + trans(context, input, &output, axis); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); + return output; +} + +/** + * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the + * original x_dim is returned. + */ +static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return framework::make_ddim({1, x_dim[0]}); +} + +/** + * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the + * original y_dim is returned. + */ +static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return framework::make_ddim({y_dim[0], 1}); +} + +/** + * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. + * + * The shape would be [BatchSize, H, W] or [H, W]. + * If transposed, `H,W` will be swapped. + */ +static void ReshapeTensorIntoMatrixSequence( + framework::Tensor* x, const math::MatDescriptor& descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + if (descriptor.batch_size_) { + x->Resize({descriptor.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} + +static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, + framework::Tensor* y, + framework::Tensor* out, bool trans_x, + bool trans_y) { + auto x_dim = RowMatrixFromVector(x->dims()); + auto y_dim = ColumnMatrixFromVector(y->dims()); + auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, mat_dim_y.width_}); + } + + ReshapeTensorIntoMatrixSequence(x, mat_dim_x); + ReshapeTensorIntoMatrixSequence(y, mat_dim_y); +} + template class MatMulV2GradKernel : public framework::OpKernel { public: + void MatMul(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + const framework::Tensor& b, bool trans_b, + framework::Tensor* out) const { + out->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + if (a.dims().size() == 3 && b.dims().size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } + blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, + static_cast(0)); + } + + void CalcInputGrad(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + bool is_fold_init_dims_a, const framework::Tensor& b, + bool trans_b, bool is_fold_init_dims_b, + framework::Tensor* out) const { + if (out == nullptr) return; + bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && + out->dims().size() == 2; + if (!need_combine) { + MatMul(context, a, trans_a, b, trans_b, out); + } else { + auto& ctx = context.template device_context(); + MatMul(context, is_fold_init_dims_a + ? FoldInitDims(a) + : FoldHeadAndLastDims(ctx, a), + trans_a, is_fold_init_dims_b + ? FoldInitDims(b) + : FoldHeadAndLastDims(ctx, b), + trans_b, out); + } + } + void Compute(const framework::ExecutionContext& ctx) const override { - auto* X = ctx.Input("X"); - auto* Y = ctx.Input("Y"); - auto* dOut = ctx.Input(framework::GradVarName("Out")); - bool trans_x = ctx.Attr("trans_x"); - bool trans_y = ctx.Attr("trans_y"); + // auto* X = ctx.Input("X"); + // auto* Y = ctx.Input("Y"); + // auto* dOut = ctx.Input(framework::GradVarName("Out")); + bool transpose_x = ctx.Attr("trans_x"); + bool transpose_y = ctx.Attr("trans_y"); + + auto x = *ctx.Input("X"); + auto y = *ctx.Input("Y"); + auto dout = *ctx.Input(framework::GradVarName("Out")); // get dims - std::vector x_dims = vectorize(X->dims()); - std::vector y_dims = vectorize(Y->dims()); - std::vector dout_dims = vectorize(dOut->dims()); + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); int x_ndim = x_dims.size(); int y_ndim = y_dims.size(); @@ -372,115 +511,156 @@ class MatMulV2GradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - // x's or y's dim = 1 + // Case1 : x's or y's dim = 1 if (x_ndim == 1 && y_ndim == 1) { if (dx) dx->mutable_data(ctx.GetPlace()); if (dy) dy->mutable_data(ctx.GetPlace()); - if (dOut->numel() == 1) { - DotGradFunction(X, Y, dOut, dx, dy, ctx); + if (dout.numel() == 1) { + DotGradFunction(&x, &y, &dout, dx, dy, ctx); return; } } - // It is very tricky. For this broadcast, currently using the reduce sum to - // get gradient. - if (x_ndim == 1) { - x_dims.insert(x_dims.begin() + 0, 1); - x_ndim += 1; - if (trans_x) - dout_dims.push_back(1); - else - dout_dims.insert(dout_dims.begin() + ndim - 1, 1); - ndim += 1; - } - if (y_ndim == 1) { - y_dims.push_back(1); - y_ndim += 1; - if (trans_y) - dout_dims.insert(dout_dims.begin() + ndim - 1, 1); - else - dout_dims.push_back(1); - ndim += 1; + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, + y_dims.cbegin()); } - // the normal case - Tensor dx_help, dy_help; - if (trans_x) { - if (trans_y) { - // X'Y': dA = Y'G', dB = G'X' - if (dx) - MatMulFunction(Y, dOut, y_dims, dout_dims, &dx_help, - true, true, ctx); - if (dy) - MatMulFunction(dOut, X, dout_dims, x_dims, &dy_help, - true, true, ctx); + VLOG(0) << "is_broadcast: " << is_broadcast; + // Case2: no broadcast or no batch size, it aims to speed and it is same as + // matmul in old version. + if (!is_broadcast) { + ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + framework::DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); + } + } + + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); + } + } + if (transpose_x && transpose_y) { + CalcInputGrad(ctx, y, true, true, dout, true, false, dx); + CalcInputGrad(ctx, dout, true, true, x, true, false, dy); + } else if (transpose_x) { + CalcInputGrad(ctx, y, false, false, dout, true, false, dx); + CalcInputGrad(ctx, x, false, false, dout, false, true, dy); + } else if (transpose_y) { + CalcInputGrad(ctx, dout, false, false, y, false, true, dx); + CalcInputGrad(ctx, dout, true, true, x, false, true, dy); } else { - // X'Y: dX = YG', dY = XG - if (dx) - MatMulFunction(Y, dOut, y_dims, dout_dims, &dx_help, - false, true, ctx); - if (dy) - MatMulFunction(X, dOut, x_dims, dout_dims, &dy_help, - false, false, ctx); + CalcInputGrad(ctx, dout, false, false, y, true, false, dx); + CalcInputGrad(ctx, x, true, true, dout, false, true, dy); + } + + if (dx) { + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); + } + } + if (dy) { + if (dy_dims != y.dims()) { + dy->Resize(dy_dims); + } } } else { - if (trans_y) { - // XY': dX = GY, dY = G'X - if (dx) - MatMulFunction(dOut, Y, dout_dims, y_dims, &dx_help, - false, false, ctx); - if (dy) - MatMulFunction(dOut, X, dout_dims, x_dims, &dy_help, - true, false, ctx); + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + Tensor dx_help, dy_help; + if (transpose_x) { + if (transpose_y) { + // X'Y': dA = Y'G', dB = G'X' + if (dx) + MatMulFunction(&y, &dout, y_dims, dout_dims, + &dx_help, true, true, ctx); + if (dy) + MatMulFunction(&dout, &x, dout_dims, x_dims, + &dy_help, true, true, ctx); + } else { + // X'Y: dX = YG', dY = XG + if (dx) + MatMulFunction(&y, &dout, y_dims, dout_dims, + &dx_help, false, true, ctx); + if (dy) + MatMulFunction(&x, &dout, x_dims, dout_dims, + &dy_help, false, false, ctx); + } } else { - // XY: dX = GY', dY = X'G - if (dx) - MatMulFunction(dOut, Y, dout_dims, y_dims, &dx_help, - false, true, ctx); - if (dy) - MatMulFunction(X, dOut, x_dims, dout_dims, &dy_help, - true, false, ctx); + if (transpose_y) { + // XY': dX = GY, dY = G'X + if (dx) + MatMulFunction(&dout, &y, dout_dims, y_dims, + &dx_help, false, false, ctx); + if (dy) + MatMulFunction(&dout, &x, dout_dims, x_dims, + &dy_help, true, false, ctx); + } else { + // XY: dX = GY', dY = X'G + if (dx) + MatMulFunction(&dout, &y, dout_dims, y_dims, + &dx_help, false, true, ctx); + if (dy) + MatMulFunction(&x, &dout, x_dims, dout_dims, + &dy_help, true, false, ctx); + } } - } - // get help dims - const std::vector dx_help_dims = vectorize(dx_help.dims()); - const std::vector dy_help_dims = vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); + + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill(dx_broadcast_dims.data(), + dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill(dy_broadcast_dims.data(), + dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + + // reduce sum to get grad by ReduceSum + if (dx) { + dx->Resize(dx_help.dims()); + ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, + ctx); + dx->Resize(x.dims()); } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); + if (dy) { + dy->Resize(dy_help.dims()); + ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, + ctx); + dy->Resize(y.dims()); } } - // reduce sum to get grad by ReduceSum - if (dx) { - dx->Resize(dx_help.dims()); - ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, - ctx); - dx->Resize(X->dims()); - } - if (dy) { - dy->Resize(dy_help.dims()); - ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, - ctx); - dy->Resize(Y->dims()); - } } };