diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 8a54b04f4d8021c5989493ce41d829c15b467ddf..5b49a0d591edd9b8bd9403a2f330a16cc0efe8ec 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -198,15 +198,15 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { // "embedding_fc_lstm_fuse_pass", // // TODO(wilber): fix correctness problem. // "fc_lstm_fuse_pass", // - "mul_lstm_fuse_pass", // - "fc_gru_fuse_pass", // - "mul_gru_fuse_pass", // - "seq_concat_fc_fuse_pass", // - "squeeze2_matmul_fuse_pass", // - "reshape2_matmul_fuse_pass", // - "flatten2_matmul_fuse_pass", // - "map_matmul_v2_to_mul_pass", // - // "map_matmul_v2_to_matmul_pass", // + "mul_lstm_fuse_pass", // + "fc_gru_fuse_pass", // + "mul_gru_fuse_pass", // + "seq_concat_fc_fuse_pass", // + "squeeze2_matmul_fuse_pass", // + "reshape2_matmul_fuse_pass", // + "flatten2_matmul_fuse_pass", // + "map_matmul_v2_to_mul_pass", // + "map_matmul_v2_to_matmul_pass", // "map_matmul_to_mul_pass", // "fc_fuse_pass", // "repeated_fc_relu_fuse_pass", // diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 4e435660ff6dc48ac6d7f7da788d07bbad3b6a89..051f97ad4ec8de8a56407e13c7221e6f0e4d1046 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -336,6 +336,8 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx, "The Input(%s) has not been initialized properly. The " "shape of Input(%s) = [%s].", dim)); + + // if mkldnn reshape+transpose+matmul fuse activated if (!shape.empty() && !axis.empty()) { PADDLE_ENFORCE_GE( shape.size(), 2, @@ -355,6 +357,43 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx, "Ranks of shape_%s and axis_%s attributes of MatMulOp " "must be equal.", input_name, input_name)); + + int num_negative = std::count(shape.begin(), shape.end(), -1); + PADDLE_ENFORCE_LE(num_negative, 1, + platform::errors::InvalidArgument( + "The max number of -1 in fused_reshape_%s is 1 " + "but received %d.", + input_name, num_negative)); + + auto it_zero = std::find(shape.begin(), shape.end(), 0); + if (it_zero != shape.end()) { + for (uint64_t i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + PADDLE_ENFORCE_LT(i, dim.size(), + platform::errors::InvalidArgument( + "The index of 0 in fused_reshape_%s ", + "should be less than output dim size, ", + "but the index is %d and output dim size is %d", + input_name, i, dim.size())); + shape[i] = dim.at(i); + } + } + } + + // if "-1" is present then one of reshape dims must be infered + auto it_negative = std::find(shape.begin(), shape.end(), -1); + if (it_negative != shape.end()) { + int64_t dim_product = 1; + for (int i = 0; i < dim.size(); i++) { + dim_product *= dim.at(i); + } + + int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, + std::multiplies()); + int index = std::distance(shape.begin(), it_negative); + shape[index] = dim_product / shape_product; + } + dim = dim.reshape(shape).transpose(axis); } return dim; diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index b78acd32e6dc8fffb6be48679164fae565a111bf..b7eb5a3ab4b57cf4d29a7f6f8b3d1ff65d9a330d 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -245,6 +245,36 @@ class MatMulMKLDNNHandler auto input_dims = ctx.Input(input_name)->dims(); auto new_dims = input_dims; if (!shape.empty() && !axis.empty()) { + auto it_zero = std::find(shape.begin(), shape.end(), 0); + if (it_zero != shape.end()) { + for (uint64_t i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + PADDLE_ENFORCE_LT( + i, input_dims.size(), + paddle::platform::errors::InvalidArgument( + "The index of 0 in fused_reshape_%s ", + "should be less than output dim size, ", + "but the index is %d and output dim size is %d", input_name, + i, input_dims.size())); + shape[i] = input_dims.at(i); + } + } + } + + // if "-1" is present then one of reshape dims must be infered + auto it_negative = std::find(shape.begin(), shape.end(), -1); + if (it_negative != shape.end()) { + int64_t dim_product = 1; + for (int i = 0; i < input_dims.size(); i++) { + dim_product *= input_dims.at(i); + } + + int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, + std::multiplies()); + int index = std::distance(shape.begin(), it_negative); + shape[index] = dim_product / shape_product; + } + new_dims = input_dims.reshape(shape).transpose(axis); }