diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
index b4c18c42a7619b3f6533094013e497bca98e9d84..5c3dd0cb1234aff810cbb480c9ee40db37eb6363 100644
--- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -103,6 +103,34 @@ static paddle::framework::DDim ColumnMatrixDimsFromVector(
   return y_dim.size() > 1 ? y_dim : phi::make_ddim({y_dim[0], 1});
 }
 
+phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
+  auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
+  auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
+  auto input_dims = ctx.Input(input_name)->dims();
+  if (!shape.empty() && !axis.empty()) {
+    auto it_zero = std::find(shape.begin(), shape.end(), 0);
+    if (it_zero != shape.end()) {
+      for (uint64_t i = 0; i < shape.size(); i++) {
+        if (shape[i] == 0) {
+          PADDLE_ENFORCE_LT(i,
+                            input_dims.size(),
+                            paddle::platform::errors::InvalidArgument(
+                                "The index of 0 in fused_reshape_%s ",
+                                "should be less than output dim size, ",
+                                "but the index is %d and output dim size is %d",
+                                input_name,
+                                i,
+                                input_dims.size()));
+          shape[i] = input_dims.at(i);
+        }
+      }
+    }
+
+    return input_dims.reshape(shape).transpose(axis);
+  }
+  return input_dims;
+}
+
 template <typename XT, typename YT, typename OT>
 class MatMulMKLDNNHandler
     : public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
@@ -235,36 +263,6 @@ class MatMulMKLDNNHandler
         out_strides;
   };
 
-  phi::DDim GetDimForInput(const ExecutionContext &ctx,
-                           std::string input_name) {
-    auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
-    auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
-    auto input_dims = ctx.Input(input_name)->dims();
-    if (!shape.empty() && !axis.empty()) {
-      auto it_zero = std::find(shape.begin(), shape.end(), 0);
-      if (it_zero != shape.end()) {
-        for (uint64_t i = 0; i < shape.size(); i++) {
-          if (shape[i] == 0) {
-            PADDLE_ENFORCE_LT(
-                i,
-                input_dims.size(),
-                paddle::platform::errors::InvalidArgument(
-                    "The index of 0 in fused_reshape_%s ",
-                    "should be less than output dim size, ",
-                    "but the index is %d and output dim size is %d",
-                    input_name,
-                    i,
-                    input_dims.size()));
-            shape[i] = input_dims.at(i);
-          }
-        }
-      }
-
-      return input_dims.reshape(shape).transpose(axis);
-    }
-    return input_dims;
-  }
-
   std::pair<phi::DDim, memory::dims> GetInputDimsAndStrides(
       const ExecutionContext &ctx, std::string input_name) {
     auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
@@ -599,6 +597,23 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
   auto input_dims = ctx.Input(input_name)->dims();
   auto new_dims = input_dims;
   if (!shape.empty() && !axis.empty()) {
+    auto it_zero = std::find(shape.begin(), shape.end(), 0);
+    if (it_zero != shape.end()) {
+      for (uint64_t i = 0; i < shape.size(); i++) {
+        if (shape[i] == 0) {
+          PADDLE_ENFORCE_LT(i,
+                            input_dims.size(),
+                            paddle::platform::errors::InvalidArgument(
+                                "The index of 0 in fused_reshape_%s ",
+                                "should be less than output dim size, ",
+                                "but the index is %d and output dim size is %d",
+                                input_name,
+                                i,
+                                input_dims.size()));
+          shape[i] = input_dims.at(i);
+        }
+      }
+    }
     new_dims = input_dims.reshape(shape).transpose(axis);
   }
 
@@ -693,18 +708,6 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
   out->set_format(format);
 }
 
-paddle::framework::DDim GetDimForInput(
-    const paddle::framework::ExecutionContext &ctx,
-    const std::string &input_name) {
-  auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
-  auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
-  auto dim = ctx.Input(input_name)->dims();
-  if (!shape.empty() && !axis.empty()) {
-    dim = dim.reshape(shape).transpose(axis);
-  }
-  return dim;
-}
-
 template <typename T>
 class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
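
For reference, a minimal standalone sketch of the zero-entry substitution that this patch applies in both GetDimForInput and GetInputStrides: a 0 in the fused_reshape_X/Y shape attribute is replaced by the size of the input dimension at the same index before reshape(shape).transpose(axis) runs. This is plain C++, not Paddle code; the helper name ResolveZeroEntries and the example dims are made up for illustration.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Replace every 0 in `shape` with the input dimension at the same index,
// mirroring the substitution done before reshape/transpose in the patch.
std::vector<int> ResolveZeroEntries(std::vector<int> shape,
                                    const std::vector<int64_t> &input_dims) {
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) {
      // The patch guards this index with PADDLE_ENFORCE_LT(i, input_dims.size(), ...);
      // std::vector::at() throws std::out_of_range here instead.
      shape[i] = static_cast<int>(input_dims.at(i));
    }
  }
  return shape;
}

int main() {
  // Hypothetical values for illustration only.
  std::vector<int64_t> input_dims = {2, 128, 768};
  std::vector<int> fused_reshape = {0, 0, 12, 64};  // 0 -> keep that input dim
  for (int d : ResolveZeroEntries(fused_reshape, input_dims)) {
    std::cout << d << ' ';  // prints: 2 128 12 64
  }
  std::cout << '\n';
  return 0;
}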