diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc
index ce892aa86838a0ea36e5a41da05b4cc85cd2a8e3..a71961837681cc45f6438b48e495d3c10dc9e742 100644
--- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc
@@ -66,27 +66,23 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse(
   auto transpose_axis = PADDLE_GET_CONST(std::vector<int>,
                                          transpose_op->Op()->GetAttr("axis"));
-  auto reshape_out_size = reshape_shape.size();
-  auto transpose_out_size = transpose_axis.size();
   const std::vector<int> supported_axis{0, 2, 1, 3};
-  const bool supported_transpose_axis = std::equal(
-      transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
-  if (transpose_out_size != 4) {
-    VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
-            << "supported rank is 4, received " << transpose_out_size;
-    return;
-  }
-  if (!supported_transpose_axis) {
+  if (transpose_axis != supported_axis) {
     VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
             << "supported transpose axis for the fuse are {0, 2, 1, 3}";
     return;
   }
-  if (reshape_out_size != 3) {
+  if (reshape_shape.size() != 3) {
     VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
             << "reshape_out supported rank is 3, received "
-            << reshape_out_size;
+            << reshape_shape.size();
     return;
   }
+  if (std::count(reshape_shape.begin(), reshape_shape.end(), -1) > 1) {
+    VLOG(3) << "Only one dim can be undefined / marked as '-1'";
+    return;
+  }
+
   OpDesc *matmul_desc = matmul_op->Op();
   matmul_desc->SetOutput("Out", {reshape_out->Name()});
   matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
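[Editor's note] The single `transpose_axis != supported_axis` comparison above subsumes both removed checks: `std::vector::operator==` compares sizes before elements, so any rank other than 4 is rejected immediately, and it also avoids the read past the end of `supported_axis` that the old code risked by evaluating `std::equal` before the size check. A minimal standalone sketch of that behavior (illustrative vectors, not part of the patch):

```cpp
#include <cassert>
#include <vector>

int main() {
  const std::vector<int> supported_axis{0, 2, 1, 3};

  // Rank mismatch: operator== compares sizes first, so no elements are read.
  assert(std::vector<int>({0, 2, 1}) != supported_axis);
  // Same rank, wrong order: rejected element-wise.
  assert(std::vector<int>({0, 1, 2, 3}) != supported_axis);
  // Exact match: the only layout the fuse pass accepts.
  assert(std::vector<int>({0, 2, 1, 3}) == supported_axis);
  return 0;
}
```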
diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
index 29e013c55a40bd8d297887f8625d4ad627ae8153..bb6ceb6064c638078bb6518c194b4f2c9fc8bd94 100644
--- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
@@ -23,14 +23,24 @@ namespace ir {
 
 void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(Graph *graph) const {
   auto matmul_types = {"matmul", "matmul_v2"};
-  bool with_reshape_xshape = true;
-  bool with_transpose_xshape = true;
 
   for (const auto &matmul_type : matmul_types) {
-    Fuse(graph, matmul_type, with_reshape_xshape, with_transpose_xshape);
-    Fuse(graph, matmul_type, with_reshape_xshape, !with_transpose_xshape);
-    Fuse(graph, matmul_type, !with_reshape_xshape, with_transpose_xshape);
-    Fuse(graph, matmul_type, !with_reshape_xshape, !with_transpose_xshape);
+    Fuse(graph,
+         matmul_type,
+         false /*with_reshape_xshape*/,
+         false /*with_transpose_xshape*/);
+    Fuse(graph,
+         matmul_type,
+         false /*with_reshape_xshape*/,
+         true /*with_transpose_xshape*/);
+    Fuse(graph,
+         matmul_type,
+         true /*with_reshape_xshape*/,
+         false /*with_transpose_xshape*/);
+    Fuse(graph,
+         matmul_type,
+         true /*with_reshape_xshape*/,
+         true /*with_transpose_xshape*/);
   }
 }
@@ -80,29 +90,44 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
     GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, rtm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, rtm_pattern);
 
-    auto reshape_shape =
-        paddle::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
-    auto transpose_axis =
-        paddle::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
-
     OpDesc *matmul_desc = matmul_op->Op();
     std::string input_var_name = transpose_out->Name();
-
-    auto UpdateMatmul = [&](std::string matmul_input_name) {
-      matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
-      matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
-      matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
-                           transpose_axis);
-    };
+    std::string matmul_input_name;
     if (matmul_desc->Inputs().at("X").at(0) == input_var_name) {
-      UpdateMatmul("X");
+      matmul_input_name = "X";
     } else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) {
-      UpdateMatmul("Y");
+      matmul_input_name = "Y";
     } else {
       throw platform::errors::InvalidArgument("Unexpected input to " +
                                               matmul_type + " encountered.");
     }
 
+    auto reshape_shape =
+        paddle::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
+    auto transpose_axis =
+        paddle::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
+
+    if (reshape_shape.size() < 2 || reshape_shape.size() > 4) {
+      VLOG(3) << "shape_" + matmul_input_name + " attribute of " +
+                     matmul_type + " was implemented for 2, 3 or 4 dimensions.";
+      return;
+    }
+    if (reshape_shape.size() != transpose_axis.size()) {
+      VLOG(3) << "Ranks of shape_" + matmul_input_name + " and axis_" +
+                     matmul_input_name + " attributes of " + matmul_type +
+                     " must be equal.";
+      return;
+    }
+    if (std::count(reshape_shape.begin(), reshape_shape.end(), -1) > 1) {
+      VLOG(3) << "Only one dim can be undefined / marked as '-1'";
+      return;
+    }
+
+    matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
+    matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
+    matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
+                         transpose_axis);
+
     std::unordered_set<const Node *> nodes_to_remove{
         reshape_op, reshape_out, transpose_op, transpose_out};
     if (with_reshape_xshape) nodes_to_remove.insert(reshape_xshape);
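[Editor's note] The reordering above makes the pass validate before it mutates: the matmul input ("X" or "Y") is resolved first, the reshape/transpose attributes are checked next, and `matmul_desc` is only touched once every guard has passed. A standalone sketch of the three early-return guards, as a hypothetical helper (`CanFuseReshapeTranspose` is illustrative, not part of the patch):

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper mirroring the guards added in the pass above.
bool CanFuseReshapeTranspose(const std::vector<int> &shape,
                             const std::vector<int> &axis) {
  // fused_reshape_{X,Y} is implemented for 2, 3 or 4 dimensions only.
  if (shape.size() < 2 || shape.size() > 4) return false;
  // reshape and transpose ranks must match for reshape(...).transpose(...).
  if (shape.size() != axis.size()) return false;
  // At most one dimension may be left undefined as -1.
  return std::count(shape.begin(), shape.end(), -1) <= 1;
}
```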
diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index ff7ab502e8efefeb3235977fddc59430f1e456a7..80018ddb1c9c2b83f9cb1d6cfefc3666e84ad258 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -358,58 +358,7 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
                                         "shape of Input(%s) = [%s].",
                                         dim));
 
-  // if mkldnn reshape+transpose+matmul fuse activated
   if (!shape.empty() && !axis.empty()) {
-    PADDLE_ENFORCE_GE(
-        shape.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "shape_%s attribute of MatMulOp was implemented for 2, 3 "
-            "or 4 dimensions.",
-            input_name));
-    PADDLE_ENFORCE_LE(
-        shape.size(),
-        4,
-        platform::errors::InvalidArgument(
-            "shape_%s attribute of MatMulOp was implemented for 2, 3 "
-            "or 4 dimensions.",
-            input_name));
-    PADDLE_ENFORCE_EQ(
-        shape.size(),
-        axis.size(),
-        platform::errors::InvalidArgument(
-            "Ranks of shape_%s and axis_%s attributes of MatMulOp "
-            "must be equal.",
-            input_name,
-            input_name));
-
-    int num_negative = std::count(shape.begin(), shape.end(), -1);
-    PADDLE_ENFORCE_LE(num_negative,
-                      1,
-                      platform::errors::InvalidArgument(
-                          "The max number of -1 in fused_reshape_%s is 1 "
-                          "but received %d.",
-                          input_name,
-                          num_negative));
-
-    auto it_zero = std::find(shape.begin(), shape.end(), 0);
-    if (it_zero != shape.end()) {
-      for (uint64_t i = 0; i < shape.size(); i++) {
-        if (shape[i] == 0) {
-          PADDLE_ENFORCE_LT(i,
-                            dim.size(),
-                            platform::errors::InvalidArgument(
-                                "The index of 0 in fused_reshape_%s ",
-                                "should be less than output dim size, ",
-                                "but the index is %d and output dim size is %d",
-                                input_name,
-                                i,
-                                dim.size()));
-          shape[i] = dim.at(i);
-        }
-      }
-    }
-
     dim = dim.reshape(shape).transpose(axis);
   }
   return dim;
@@ -732,14 +681,11 @@ class MatMulOp : public framework::OperatorWithKernel {
     framework::DDim ddim_out = phi::make_ddim(dim_out);
 
 #ifdef PADDLE_WITH_MKLDNN
-    // if mkldnn matmul+transpose+reshape fuse activated
-    auto reshape_out =
-        context->Attrs().Get<std::vector<int>>("fused_reshape_Out");
-    auto transpose_out =
-        context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
-
-    if (!reshape_out.empty() && !transpose_out.empty()) {
-      ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out);
+    auto shape = context->Attrs().Get<std::vector<int>>("fused_reshape_Out");
+    auto axis = context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
+
+    if (!shape.empty() && !axis.empty()) {
+      ddim_out = ddim_out.transpose(axis).reshape(shape);
     }
 #endif
     context->SetOutputDim("Out", ddim_out);
diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc
index 29e98092f003ad33f8505c636870716ff892ec28..8c045630afb4d34cb981e8bfb7710ec70c4e66a9 100644
--- a/paddle/fluid/operators/matmul_v2_op.cc
+++ b/paddle/fluid/operators/matmul_v2_op.cc
@@ -38,58 +38,7 @@ static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
                                         "shape of Input(%s) = [%s].",
                                         dim));
 
-  // if mkldnn reshape+transpose+matmul fuse activated
   if (!shape.empty() && !axis.empty()) {
-    PADDLE_ENFORCE_GE(
-        shape.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "shape_%s attribute of MatMulOp was implemented for 2, 3 "
-            "or 4 dimensions.",
-            input_name));
-    PADDLE_ENFORCE_LE(
-        shape.size(),
-        4,
-        platform::errors::InvalidArgument(
-            "shape_%s attribute of MatMulOp was implemented for 2, 3 "
-            "or 4 dimensions.",
-            input_name));
-    PADDLE_ENFORCE_EQ(
-        shape.size(),
-        axis.size(),
-        platform::errors::InvalidArgument(
-            "Ranks of shape_%s and axis_%s attributes of MatMulOp "
-            "must be equal.",
-            input_name,
-            input_name));
-
-    int num_negative = std::count(shape.begin(), shape.end(), -1);
-    PADDLE_ENFORCE_LE(num_negative,
-                      1,
-                      platform::errors::InvalidArgument(
-                          "The max number of -1 in fused_reshape_%s is 1 "
-                          "but received %d.",
-                          input_name,
-                          num_negative));
-
-    auto it_zero = std::find(shape.begin(), shape.end(), 0);
-    if (it_zero != shape.end()) {
-      for (uint64_t i = 0; i < shape.size(); i++) {
-        if (shape[i] == 0) {
-          PADDLE_ENFORCE_LT(i,
-                            dim.size(),
-                            platform::errors::InvalidArgument(
-                                "The index of 0 in fused_reshape_%s ",
-                                "should be less than output dim size, ",
-                                "but the index is %d and output dim size is %d",
-                                input_name,
-                                i,
-                                dim.size()));
-          shape[i] = dim.at(i);
-        }
-      }
-    }
-
     dim = dim.reshape(shape).transpose(axis);
   }
   return dim;
@@ -169,13 +118,11 @@ class MatMulV2Op : public framework::OperatorWithKernel {
     auto ddim_out = phi::make_ddim(new_dims);
 
 #ifdef PADDLE_WITH_MKLDNN
-    // if mkldnn matmul_v2+transpose+reshape fuse activated
-    auto reshape_out = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
-    auto transpose_out =
-        ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
+    auto shape = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
+    auto axis = ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
 
-    if (!reshape_out.empty() && !transpose_out.empty()) {
-      ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out);
+    if (!shape.empty() && !axis.empty()) {
+      ddim_out = ddim_out.transpose(axis).reshape(shape);
     }
 #endif
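[Editor's note] The same simplification lands in both matmul_op.cc and matmul_v2_op.cc above: with validation moved into the fuse passes and the 0-placeholder handled by `DDim::reshape`, InferShape reduces to the dim algebra itself. A worked standalone example of the output-side propagation `ddim_out.transpose(axis).reshape(shape)` (plain vectors stand in for `phi::DDim`; the dims are illustrative):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> dims{2, 12, 9, 64};   // raw matmul output dims
  const std::vector<int> axis{0, 2, 1, 3};   // fused_transpose_Out
  const std::vector<int> shape{0, 0, 768};   // fused_reshape_Out

  // transpose: out[i] = dims[axis[i]]  ->  {2, 9, 12, 64}
  std::vector<int64_t> transposed(axis.size());
  for (std::size_t i = 0; i < axis.size(); ++i) transposed[i] = dims[axis[i]];

  // reshape: 0 copies the input dim at the same index  ->  {2, 9, 768}
  std::vector<int64_t> out(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i)
    out[i] = shape[i] == 0 ? transposed[i] : shape[i];

  for (auto d : out) std::cout << d << ' ';  // prints: 2 9 768
  return 0;
}
```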
diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
index 2f9fa210e225a0166dceb440a42318bf0d68d534..f8c9c9d86a9953231424ef53157123a35275cc78 100644
--- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -108,24 +108,6 @@ phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
   auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
   auto input_dims = ctx.Input<Tensor>(input_name)->dims();
   if (!shape.empty() && !axis.empty()) {
-    auto it_zero = std::find(shape.begin(), shape.end(), 0);
-    if (it_zero != shape.end()) {
-      for (uint64_t i = 0; i < shape.size(); i++) {
-        if (shape[i] == 0) {
-          PADDLE_ENFORCE_LT(i,
-                            input_dims.size(),
-                            paddle::platform::errors::InvalidArgument(
-                                "The index of 0 in fused_reshape_%s ",
-                                "should be less than output dim size, ",
-                                "but the index is %d and output dim size is %d",
-                                input_name,
-                                i,
-                                input_dims.size()));
-          shape[i] = input_dims.at(i);
-        }
-      }
-    }
-
     return input_dims.reshape(shape).transpose(axis);
   }
   return input_dims;
@@ -225,12 +207,7 @@ class MatMulMKLDNNHandler
       src_memory_p->set_data_handle(x_ptr);
       weights_memory_p->set_data_handle(y_ptr);
       dst_memory_p->set_data_handle(out_ptr);
-      matmul_p->execute(astream,
-                        {
-                            {DNNL_ARG_SRC, *src_memory_p},
-                            {DNNL_ARG_WEIGHTS, *weights_memory_p},
-                            {DNNL_ARG_DST, *dst_memory_p},
-                        });
+      matmul_p->execute(astream, matmul_args);
       x_ptr = static_cast<char *>(x_ptr) + std::get<0>(offsets);
       y_ptr = static_cast<char *>(y_ptr) + std::get<1>(offsets);
       out_ptr = static_cast<char *>(out_ptr) + std::get<2>(offsets);
@@ -270,25 +247,6 @@ class MatMulMKLDNNHandler
     auto input_dims = ctx.Input<Tensor>(input_name)->dims();
     auto new_dims = input_dims;
     if (!shape.empty() && !axis.empty()) {
-      auto it_zero = std::find(shape.begin(), shape.end(), 0);
-      if (it_zero != shape.end()) {
-        for (uint64_t i = 0; i < shape.size(); i++) {
-          if (shape[i] == 0) {
-            PADDLE_ENFORCE_LT(
-                i,
-                input_dims.size(),
-                paddle::platform::errors::InvalidArgument(
-                    "The index of 0 in fused_reshape_%s ",
-                    "should be less than output dim size, ",
-                    "but the index is %d and output dim size is %d",
-                    input_name,
-                    i,
-                    input_dims.size()));
-            shape[i] = input_dims.at(i);
-          }
-        }
-      }
-
       new_dims = input_dims.reshape(shape).transpose(axis);
     }
 
@@ -519,7 +477,10 @@ static void ExecuteMatMul(const ExecutionContext &ctx) {
  constexpr bool is_int8 = IsInt8<T>();
  constexpr bool is_bfloat16 = IsBfloat16<T>();
  const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-  constexpr bool fuse_relu = false;  // TODO(intel): Enable eltwise fuses
+  const bool fuse_relu =
+      ctx.HasAttr("fuse_activation")
+          ? ctx.Attr<std::string>("fuse_activation") == "relu"
+          : false;
   auto *x = ctx.Input<Tensor>("X");
   auto *y = ctx.Input<Tensor>("Y");
   auto *out = ctx.Output<Tensor>("Out");
@@ -596,23 +557,6 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
   auto input_dims = ctx.Input<Tensor>(input_name)->dims();
   auto new_dims = input_dims;
   if (!shape.empty() && !axis.empty()) {
-    auto it_zero = std::find(shape.begin(), shape.end(), 0);
-    if (it_zero != shape.end()) {
-      for (uint64_t i = 0; i < shape.size(); i++) {
-        if (shape[i] == 0) {
-          PADDLE_ENFORCE_LT(i,
-                            input_dims.size(),
-                            paddle::platform::errors::InvalidArgument(
-                                "The index of 0 in fused_reshape_%s ",
-                                "should be less than output dim size, ",
-                                "but the index is %d and output dim size is %d",
-                                input_name,
-                                i,
-                                input_dims.size()));
-          shape[i] = input_dims.at(i);
-        }
-      }
-    }
     new_dims = input_dims.reshape(shape).transpose(axis);
   }
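[Editor's note] The `fuse_relu` change above switches from a hard-coded `false` to reading the `fuse_activation` attribute when it is present; a graph the activation fuse pass never visited has no such attribute and still means no fused eltwise. A hedged sketch of that guard as a generic helper (`FuseReluRequested` is illustrative, not part of the patch):

```cpp
#include <string>

// Illustrative helper: any context type exposing HasAttr/Attr can be
// queried this way; defaults to "no fused relu" when the attribute is
// absent.
template <typename Context>
bool FuseReluRequested(const Context &ctx) {
  return ctx.HasAttr("fuse_activation") &&
         ctx.template Attr<std::string>("fuse_activation") == "relu";
}
```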
ctx.Attr("fuse_activation") == "relu" + : false; auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *out = ctx.Output("Out"); @@ -596,23 +557,6 @@ std::vector GetInputStrides(const ExecutionContext &ctx, auto input_dims = ctx.Input(input_name)->dims(); auto new_dims = input_dims; if (!shape.empty() && !axis.empty()) { - auto it_zero = std::find(shape.begin(), shape.end(), 0); - if (it_zero != shape.end()) { - for (uint64_t i = 0; i < shape.size(); i++) { - if (shape[i] == 0) { - PADDLE_ENFORCE_LT(i, - input_dims.size(), - paddle::platform::errors::InvalidArgument( - "The index of 0 in fused_reshape_%s ", - "should be less than output dim size, ", - "but the index is %d and output dim size is %d", - input_name, - i, - input_dims.size())); - shape[i] = input_dims.at(i); - } - } - } new_dims = input_dims.reshape(shape).transpose(axis); } diff --git a/paddle/phi/core/ddim.cc b/paddle/phi/core/ddim.cc index 1809c413bc146caf4c1a94d92d52a5fd5ccc96f5..18778c9abf60f6051599fe4419b8029ac8b812c1 100644 --- a/paddle/phi/core/ddim.cc +++ b/paddle/phi/core/ddim.cc @@ -172,10 +172,13 @@ DDim stride_numel(const DDim& ddim) { } DDim DDim::reshape(std::vector& shape) const { - const int64_t copy_dim_val = 0; const DDim& in_dims = *this; - DDim out_dims; - out_dims.rank_ = shape.size(); + + for (uint64_t i = 0; i < shape.size(); ++i) { + if (shape[i] == 0) { + shape[i] = in_dims.at(i); + } + } // Dim marked as "-1" must be inferred auto it = std::find(shape.begin(), shape.end(), -1); @@ -186,54 +189,14 @@ DDim DDim::reshape(std::vector& shape) const { shape[index] = product(in_dims) / reshape_out_product; } - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == copy_dim_val) { - PADDLE_ENFORCE_LT(static_cast(i), - in_dims.size(), - phi::errors::InvalidArgument( - "Index %d of shape under which the value of 0 " - "is stored, must be lower than the number of " - "old dimensions. But received shape[%d] = 0, " - "dimensions = %d, shape = [%s].", - i, - in_dims.size(), - in_dims)); - out_dims[i] = in_dims[i]; - } else { - out_dims[i] = shape[i]; - } - } - return out_dims; + return phi::make_ddim(shape); } DDim DDim::transpose(const std::vector& axis) const { const DDim& in_dims = *this; - size_t in_rank = in_dims.size(); - size_t axis_size = axis.size(); - - auto axis_set = std::set(axis.begin(), axis.end()); - PADDLE_ENFORCE_EQ(axis_set.size(), - axis_size, - phi::errors::InvalidArgument( - "In an axis array, elements must be unique.")); - - PADDLE_ENFORCE_EQ( - in_rank, - axis_size, - phi::errors::InvalidArgument("The input dimension's size " - "should be equal to the axis's size. 
" - "But received dimension is %d, " - "axis's size is %d", - in_rank, - axis_size)); - - PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), - axis_size, - phi::errors::InvalidArgument( - "Axis values must be ranging from 0 to (dims - 1).")); DDim out_dims(in_dims); - for (size_t i = 0; i < axis_size; i++) { + for (size_t i = 0; i < axis.size(); i++) { out_dims[i] = in_dims[axis[i]]; } return out_dims; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py index 20028fb335b8fc1d3d26a451620a4b0ce2f3b786..b894fc708b4243b802788981d7c7a34f43d6ad0d 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py @@ -31,9 +31,9 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest): input_dim = draw(st.sampled_from([32])) activation_type = draw( st.sampled_from([ - 'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt', - 'hard_swish', 'sigmoid', 'abs', 'relu6', 'clip', 'tanh', - 'hard_sigmoid', 'leaky_relu' + 'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish', + 'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid', + 'leaky_relu' ])) def generate_input(type): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py index 2858d7f2d4e33058e08f2edea96153f7d94214c0..153b81fa797af560fa56898db6c6a1ce54719215 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py @@ -30,9 +30,9 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest): input_dim = draw(st.sampled_from([16, 32, 64])) activation_type = draw( st.sampled_from([ - 'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt', - 'hard_swish', 'sigmoid', 'abs', 'relu6', 'clip', 'tanh', - 'hard_sigmoid', 'leaky_relu' + 'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish', + 'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid', + 'leaky_relu' ])) def generate_input(type):