From 1c6dcfd9c901fb19dae90d591eb6e06f17da5dc6 Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Sun, 27 Mar 2022 06:50:29 +0200 Subject: [PATCH] fix reshape+transpose+matmul (#40948) --- .../operators/mkldnn/matmul_mkldnn_op.cc | 45 ++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 4415fbc8cb..f4137733e3 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -235,6 +235,47 @@ class MatMulMKLDNNHandler out_strides; }; + phi::DDim GetDimForInput(const ExecutionContext& ctx, + std::string input_name) { + auto shape = ctx.Attr>("fused_reshape_" + input_name); + auto axis = ctx.Attr>("fused_transpose_" + input_name); + auto input_dims = ctx.Input(input_name)->dims(); + if (!shape.empty() && !axis.empty()) { + auto it_zero = std::find(shape.begin(), shape.end(), 0); + if (it_zero != shape.end()) { + for (uint64_t i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + PADDLE_ENFORCE_LT( + i, input_dims.size(), + paddle::platform::errors::InvalidArgument( + "The index of 0 in fused_reshape_%s ", + "should be less than output dim size, ", + "but the index is %d and output dim size is %d", input_name, + i, input_dims.size())); + shape[i] = input_dims.at(i); + } + } + } + + // if "-1" is present then one of reshape dims must be infered + auto it_negative = std::find(shape.begin(), shape.end(), -1); + if (it_negative != shape.end()) { + int64_t dim_product = 1; + for (int i = 0; i < input_dims.size(); i++) { + dim_product *= input_dims.at(i); + } + + int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, + std::multiplies()); + int index = std::distance(shape.begin(), it_negative); + shape[index] = dim_product / shape_product; + } + + return input_dims.reshape(shape).transpose(axis); + } + return input_dims; + } + std::pair GetInputDimsAndStrides( const ExecutionContext& ctx, std::string input_name) { auto shape = ctx.Attr>("fused_reshape_" + input_name); @@ -342,8 +383,8 @@ class MatMulMKLDNNHandler batch_size_ = 1; if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) { - auto& x_dims = ctx.Input("X")->dims(); - auto& y_dims = ctx.Input("Y")->dims(); + auto x_dims = GetDimForInput(ctx, "X"); + auto y_dims = GetDimForInput(ctx, "Y"); batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0]; x_bs /= batch_size_; y_bs /= batch_size_; -- GitLab