Unverified commit 5cfe1645, authored by Sławomir Siwek, committed by GitHub

Replace matmul(v2) with fused_matmul during oneDNN fuse passes (#49515)

* replace matmul with matmul_v2 in fuse passes

* Remove fusion logic from matmul

* removing fusion methods

* add proper name

* adjust namespaces

* clean attrs in python tests

* delete checkpoint and restore matmul version

* remove unused code

* matmul and reshape/transpose fuses migrated

* split MatmulOneDNN headers

* fuse activation and eltwise_add

* add fuse_activation

* matmul_transpose_reshape/reshape_transpose_matmul

* matmul + elementwise_add (fused)

* activation temporary modification

* merge newest develop

* remove dependency from other PR

* revert pbtxt

* remove placeholders from matmul_v2

* add description in OPMaker

* remove matmul_v2_op.h and all dependencies

* remove dims changing in base op

* add possibility to fuse already fused_matmul

* restart broken CI

* Empty-Commit

* revert matmul_utils.h

* codestyle

* adjust imports

* add pbtxt file

* 100% matmul unit tests coverage

* trigger CI with minimal changes to develop

* adjust changes to develop

* add fused_matmul op

* inherit base ops

* add "v2"

* move OPMaker

* Gradually add fused_matmul files

* second batch of fused_matmul changes

* split infershapes of matmul_v2 and fused_matmul

* inherit fused_matmul from matmul_v2

* Update paddle/phi/backends/onednn/onednn_reuse.h
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>

* Update paddle/phi/kernels/fusion/onednn/fused_matmul_kernel.cc
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>

---------
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>
Parent aa8cef4a
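Taken together, the oneDNN fuse passes modified below all repeat one node rewrite: retype the matched op as fused_matmul and, when the source op is the legacy matmul(v1), rename its attributes to the fused_matmul spelling. A minimal sketch of that shared pattern, using a hypothetical helper name (the PR itself inlines this logic in each pass):

#include "paddle/fluid/framework/op_desc.h"

// Hypothetical helper summarizing the rewrite each fuse pass performs inline.
static void ConvertToFusedMatmul(paddle::framework::OpDesc* op) {
  const std::string original_type = op->Type();
  op->SetType("fused_matmul");
  if (original_type == "matmul") {
    // matmul(v1) spells these attributes differently, so they are renamed
    // to the fused_matmul convention.
    op->SetAttr("trans_x", op->GetAttr("transpose_X"));
    op->SetAttr("trans_y", op->GetAttr("transpose_Y"));
    op->SetAttr("matmul_alpha", op->GetAttr("alpha"));
  }
  // matmul_v2 already uses trans_x/trans_y and has no alpha, so only the
  // type change is needed; fused_matmul nodes are left as they are.
}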
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ using string::PrettyLogDetail;
 void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
   auto act_types = GetSupportedActivations();
-  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_types = {"fused_matmul", "matmul", "matmul_v2"};
   for (const auto& matmul_type : matmul_types)
     for (auto& act_type : act_types) {
@@ -61,8 +61,17 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
     GET_IR_NODE_FROM_SUBGRAPH(
         activation_out, activation_out, matmul_act_pattern);
-    SetActivationAttrs(matmul->Op(), activation->Op(), act_type);
-    matmul->Op()->SetOutput("Out", {activation_out->Name()});
+    OpDesc* matmul_op = matmul->Op();
+
+    matmul_op->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      matmul_op->SetAttr("trans_x", matmul_op->GetAttr("transpose_X"));
+      matmul_op->SetAttr("trans_y", matmul_op->GetAttr("transpose_Y"));
+      matmul_op->SetAttr("matmul_alpha", matmul_op->GetAttr("alpha"));
+    }
+    SetActivationAttrs(matmul_op, activation->Op(), act_type);
+    matmul_op->SetOutput("Out", {activation_out->Name()});
     IR_NODE_LINK_TO(matmul, activation_out);
     GraphSafeRemoveNodes(graph, {activation, matmul_out});
@@ -88,11 +97,6 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
       .AddInput("Y")
      .IsTensor()
      .End()
-      .AddInput(
-          "ResidualData")  // Extra tensor used in matmul+elementwise_add fuse
-      .IsTensor()
-      .IsOptional()
-      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
@@ -113,8 +117,24 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
      .AddInput("Y")
      .IsTensor()
      .End()
-      .AddInput(
-          "ResidualData")  // Extra tensor used in matmul+elementwise_add fuse
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End();
+
+  AddOpCompat(OpCompat("fused_matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
      .IsTensor()
      .IsOptional()
      .End()
@@ -126,6 +146,50 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
      .End()
      .AddAttr("trans_y")
      .IsType<bool>()
+      .End()
+      .AddAttr("matmul_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_activation")
+      .IsType<std::string>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_beta")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_output_scale")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
      .End();

   AddOpCompat(OpCompat("abs"))
@@ -279,6 +343,7 @@ REGISTER_PASS(matmul_activation_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(matmul_activation_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("fused_matmul", 0)
            .LE("matmul", 1)
            .EQ("matmul_v2", 0)
            .EQ("abs", 0)
......
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ namespace ir {
 using string::PrettyLogDetail;
 void MatmulElementwiseAddMKLDNNFusePass::ApplyImpl(Graph* graph) const {
-  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_types = {"fused_matmul", "matmul", "matmul_v2"};
   auto matmul_as_x = {true, false};
   for (const auto& matmul_type : matmul_types)
@@ -65,6 +65,12 @@ void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
       return;
     }
+    matmul->Op()->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      matmul->Op()->SetAttr("trans_x", matmul->Op()->GetAttr("transpose_X"));
+      matmul->Op()->SetAttr("trans_y", matmul->Op()->GetAttr("transpose_Y"));
+      matmul->Op()->SetAttr("matmul_alpha", matmul->Op()->GetAttr("alpha"));
+    }
     matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()});
     matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()});
@@ -125,6 +131,71 @@ MatmulElementwiseAddMKLDNNFusePass::MatmulElementwiseAddMKLDNNFusePass() {
      .IsType<bool>()
      .End();
+  AddOpCompat(OpCompat("fused_matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End()
+      .AddAttr("matmul_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_activation")
+      .IsType<std::string>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_beta")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_output_scale")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End();
+
   AddOpCompat(OpCompat("elementwise_add"))
       .AddInput("X")
       .IsTensor()
@@ -149,6 +220,7 @@ REGISTER_PASS(matmul_elementwise_add_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(matmul_elementwise_add_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("fused_matmul", 0)
            .LE("matmul", 1)
            .EQ("matmul_v2", 0)
            .LE("elementwise_add", 1));
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ namespace ir {
 using string::PrettyLogDetail;
 void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(Graph *graph) const {
-  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_types = {"fused_matmul", "matmul", "matmul_v2"};
   for (const auto &matmul_type : matmul_types) {
     Fuse(graph, matmul_type);
@@ -84,6 +84,12 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse(
     }
     OpDesc *matmul_desc = matmul_op->Op();
+    matmul_desc->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
+      matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
+      matmul_desc->SetAttr("matmul_alpha", matmul_desc->GetAttr("alpha"));
+    }
     matmul_desc->SetOutput("Out", {reshape_out->Name()});
     matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
     matmul_desc->SetAttr("fused_transpose_Out", transpose_axis);
@@ -149,6 +155,71 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
      .IsType<bool>()
      .End();
+  AddOpCompat(OpCompat("fused_matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End()
+      .AddAttr("matmul_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_activation")
+      .IsType<std::string>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_beta")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_output_scale")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End();
+
   AddOpCompat(OpCompat("transpose2"))
       .AddInput("X")
       .IsTensor()
@@ -189,6 +260,7 @@ REGISTER_PASS(matmul_transpose_reshape_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("fused_matmul", 0)
            .LE("matmul", 1)
            .EQ("matmul_v2", 0)
            .EQ("transpose2", 0)
......
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ using string::PrettyLogDetail;
 void FuseOperatorScaleOneDNNPass::ApplyImpl(Graph *graph) const {
   const std::vector<std::string> fusable_ops{
       "fc",
+      "fused_matmul",
       "matmul",
       "matmul_v2",
       "elementwise_add",
@@ -85,6 +86,19 @@ void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
     scale = *(scale_tensor->data<float>());
   }
+  if (op_type == "matmul") {
+    operator_op->Op()->SetType("fused_matmul");
+    operator_op->Op()->SetAttr("trans_x",
+                               operator_op->Op()->GetAttr("transpose_X"));
+    operator_op->Op()->SetAttr("trans_y",
+                               operator_op->Op()->GetAttr("transpose_Y"));
+    operator_op->Op()->SetAttr("matmul_alpha",
+                               operator_op->Op()->GetAttr("alpha"));
+  }
+
+  if (op_type == "matmul_v2") {
+    operator_op->Op()->SetType("fused_matmul");
+  }
   operator_op->Op()->SetAttr("fused_output_scale", scale);
   operator_op->Op()->SetOutput("Out", {scale_out->Name()});
@@ -111,6 +125,7 @@ REGISTER_PASS_CAPABILITY(operator_scale_onednn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .EQ("fc", 0)
+            .EQ("fused_matmul", 0)
            .LE("matmul", 1)
            .EQ("matmul_v2", 0)
            .LE("elementwise_add", 1)
......
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ namespace framework {
 namespace ir {
 void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(Graph *graph) const {
-  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_types = {"matmul", "matmul_v2", "fused_matmul"};
   for (const auto &matmul_type : matmul_types) {
     Fuse(graph,
@@ -102,6 +102,25 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
                     matmul_type + " encountered.");
     }
+    // Return if input of fused_matmul is already fused
+    if (matmul_type == "fused_matmul") {
+      auto is_already_fused_X =
+          matmul_desc->HasAttr("fused_reshape_X")
+              ? !(PADDLE_GET_CONST(std::vector<int>,
+                                   matmul_desc->GetAttr("fused_reshape_X"))
+                      .empty())
+              : false;
+      if (is_already_fused_X && matmul_input_name == "X") return;
+
+      auto is_already_fused_Y =
+          matmul_desc->HasAttr("fused_reshape_Y")
+              ? !(PADDLE_GET_CONST(std::vector<int>,
+                                   matmul_desc->GetAttr("fused_reshape_Y"))
+                      .empty())
+              : false;
+      if (is_already_fused_Y && matmul_input_name == "Y") return;
+    }
+
     auto reshape_shape =
         paddle::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
     auto transpose_axis =
@@ -123,6 +142,12 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
       return;
     }
+    matmul_desc->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
+      matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
+      matmul_desc->SetAttr("matmul_alpha", matmul_desc->GetAttr("alpha"));
+    }
     matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
     matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
     matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
@@ -220,6 +245,71 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
      .AddAttr("trans_y")
      .IsType<bool>()
      .End();
+  AddOpCompat(OpCompat("fused_matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End()
+      .AddAttr("matmul_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_activation")
+      .IsType<std::string>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_beta")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_output_scale")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End();
 }
 }  // namespace ir
@@ -234,5 +324,6 @@ REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_mkldnn_fuse_pass)
         paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("reshape2", 0)
            .EQ("transpose2", 0)
+            .EQ("fused_matmul", 0)
            .EQ("matmul", 1)
            .EQ("matmul_v2", 0));
type: "fused_matmul"
def {
inputs {
name: "X"
}
inputs {
name: "Y"
}
inputs {
name: "ResidualData"
}
outputs {
name: "Out"
}
attrs {
name: "trans_x"
type: BOOLEAN
}
attrs {
name: "trans_y"
type: BOOLEAN
}
}
extra {
attrs {
name: "matmul_alpha"
type: FLOAT
}
attrs {
name: "fuse_activation"
type: STRING
}
attrs {
name: "fuse_alpha"
type: FLOAT
}
attrs {
name: "fuse_beta"
type: FLOAT
}
attrs {
name: "fused_output_scale"
type: FLOAT
}
attrs {
name: "fused_reshape_X"
type: INTS
}
attrs {
name: "fused_transpose_X"
type: INTS
}
attrs {
name: "fused_reshape_Y"
type: INTS
}
attrs {
name: "fused_transpose_Y"
type: INTS
}
attrs {
name: "fused_reshape_Out"
type: INTS
}
attrs {
name: "fused_transpose_Out"
type: INTS
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "Scale_x"
type: FLOAT
}
attrs {
name: "Scale_y"
type: FLOAT
}
attrs {
name: "Scale_in_eltwise"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
}
}
@@ -39,28 +39,4 @@ extra {
     name: "op_device"
     type: STRING
   }
-  attrs {
-    name: "fused_reshape_X"
-    type: INTS
-  }
-  attrs {
-    name: "fused_reshape_Y"
-    type: INTS
-  }
-  attrs {
-    name: "fused_transpose_X"
-    type: INTS
-  }
-  attrs {
-    name: "fused_transpose_Y"
-    type: INTS
-  }
-  attrs {
-    name: "fused_reshape_Out"
-    type: INTS
-  }
-  attrs {
-    name: "fused_transpose_Out"
-    type: INTS
-  }
 }
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/matmul_v2_op.h"

namespace paddle {
namespace operators {

static std::vector<int64_t> GetInputShape(phi::DDim dim,
                                          std::vector<int> shape,
                                          std::vector<int> axis) {
  PADDLE_ENFORCE_GT(dim.size(),
                    0,
                    phi::errors::InvalidArgument(
                        "The Input(%s) has not been initialized properly. The "
                        "shape of Input(%s) = [%s].",
                        dim));

  auto is_input_fused = (!shape.empty() && !axis.empty());
  if (is_input_fused) {
    dim = dim.reshape(shape).transpose(axis);
  }
  return phi::vectorize(dim);
}

class FusedMatmulOp : public MatMulV2Op {
 public:
  using MatMulV2Op::MatMulV2Op;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fused_matmul");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fused_matmul");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fused_matmul");
    bool trans_x = ctx->Attrs().Get<bool>("trans_x");
    bool trans_y = ctx->Attrs().Get<bool>("trans_y");

    std::vector<int64_t> dims_x =
        GetInputShape(ctx->GetInputDim("X"),
                      ctx->Attrs().Get<std::vector<int>>("fused_reshape_X"),
                      ctx->Attrs().Get<std::vector<int>>("fused_transpose_X"));
    std::vector<int64_t> dims_y =
        GetInputShape(ctx->GetInputDim("Y"),
                      ctx->Attrs().Get<std::vector<int>>("fused_reshape_Y"),
                      ctx->Attrs().Get<std::vector<int>>("fused_transpose_Y"));

    auto ndims_x = dims_x.size();
    auto ndims_y = dims_y.size();
    PADDLE_ENFORCE_GT(ndims_x,
                      0,
                      phi::errors::InvalidArgument(
                          "The Input(X) dims size must be greater than 0,"
                          " but received dims size is 0. "));
    PADDLE_ENFORCE_GT(ndims_y,
                      0,
                      phi::errors::InvalidArgument(
                          "The Input(Y) dims size must be greater than 0,"
                          " but received dims size is 0. "));

    bool x_broadcasted = false;
    bool y_broadcasted = false;
    if (ndims_x == 1) {
      dims_x.insert(dims_x.begin(), 1);
      ndims_x = 2;
      x_broadcasted = true;
    }
    if (ndims_y == 1) {
      dims_y.push_back(1);
      ndims_y = 2;
      y_broadcasted = true;
    }

    size_t M, N;
    if (trans_x) {
      M = dims_x[ndims_x - 1];
    } else {
      M = dims_x[ndims_x - 2];
    }
    if (trans_y) {
      N = dims_y[ndims_y - 2];
    } else {
      N = dims_y[ndims_y - 1];
    }

    std::vector<int64_t> new_dims;
    if (ndims_x > ndims_y) {
      new_dims.assign(dims_x.begin(), dims_x.end() - 2);
    } else if (ndims_x < ndims_y) {
      new_dims.assign(dims_y.begin(), dims_y.end() - 2);
    } else {
      new_dims.reserve(ndims_x);
      for (size_t i = 0; i < ndims_x - 2; ++i) {
        new_dims.push_back(std::max(dims_x[i], dims_y[i]));
      }
    }
    if (!x_broadcasted) {
      new_dims.push_back(M);
    }
    if (!y_broadcasted) {
      new_dims.push_back(N);
    }
    if (x_broadcasted && y_broadcasted) {
      new_dims.push_back(1);
    }

    auto ddim_out = phi::make_ddim(new_dims);

    auto shape = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
    auto axis = ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");

    auto is_output_fused = (!shape.empty() && !axis.empty());
    if (is_output_fused) {
      ddim_out = ddim_out.transpose(axis).reshape(shape);
    }

    ctx->SetOutputDim("Out", ddim_out);
    ctx->ShareLoD("X", "Out");
  }
};

class FusedMatmulOpMaker : public MatMulV2OpMaker {
 protected:
  void Apply() override {
    AddInput("ResidualData",
             "Extra input from matmul_elementwise_add_mkldnn_fuse_pass")
        .AsDispensable()
        .AsExtra();
    AddAttr<float>("matmul_alpha", "Output scale used in matmul_v1")
        .SetDefault(1.0f);
    AddAttr<std::string>(
        "fuse_activation",
        "Activation type from matmul_activation_mkldnn_fuse_pass")
        .SetDefault("");
    AddAttr<float>("fuse_alpha",
                   "Activation alpha from matmul_activation_mkldnn_fuse_pass")
        .SetDefault(0.0f);
    AddAttr<float>("fuse_beta",
                   "Activation beta from matmul_activation_mkldnn_fuse_pass")
        .SetDefault(0.0f);
    AddAttr<float>("fused_output_scale",
                   "Output scale from operator_scale_onednn_fuse_pass")
        .SetDefault(1.0f);
    AddAttr<std::vector<int>>("fused_reshape_X",
                              "Reshape's shape attribute from "
                              "reshape_transpose_matmul_mkldnn_fuse_pass")
        .SetDefault({});
    AddAttr<std::vector<int>>("fused_transpose_X",
                              "Transpose's axis attribute from "
                              "reshape_transpose_matmul_mkldnn_fuse_pass")
        .SetDefault({});
    AddAttr<std::vector<int>>("fused_reshape_Y",
                              "Reshape's shape attribute from "
                              "reshape_transpose_matmul_mkldnn_fuse_pass")
        .SetDefault({});
    AddAttr<std::vector<int>>("fused_transpose_Y",
                              "Transpose's axis attribute from "
                              "reshape_transpose_matmul_mkldnn_fuse_pass")
        .SetDefault({});
    AddAttr<std::vector<int>>("fused_reshape_Out",
                              "Reshape's shape attribute from "
                              "matmul_transpose_reshape_mkldnn_fuse_pass")
        .SetDefault({});
    AddAttr<std::vector<int>>("fused_transpose_Out",
                              "Transpose's axis attribute from "
                              "matmul_transpose_reshape_mkldnn_fuse_pass")
        .SetDefault({});
    AddAttr<std::string>("mkldnn_data_type", "oneDNN operator data type")
        .SetDefault("float32")
        .InEnum({"float32", "int8", "bfloat16"});
    AddAttr<float>("Scale_x", "Matmul X input quantization scale")
        .SetDefault(1.0f);
    AddAttr<float>("Scale_y", "Matmul Y input quantization scale")
        .SetDefault(1.0f);
    AddAttr<float>("Scale_in_eltwise", "Matmul ResidualData quantization scale")
        .SetDefault(0.0f);
    AddAttr<float>("Scale_out", "Matmul output quantization scale")
        .SetDefault(1.0f);
    AddAttr<bool>("force_fp32_output",
                  "Flag determining if output should be converted to FP32")
        .SetDefault(false);
    AddComment(
        R"DOC(Matrix multiplication extended with oneDNN-specific fusion logic.)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(
    fused_matmul,
    ops::FusedMatmulOp,
    ops::FusedMatmulOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
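The output-shape rule implemented by FusedMatmulOp::InferShape above can be condensed into a standalone sketch (plain std::vector instead of Paddle's DDim; the fused input/output reshape-transpose folding is omitted, and the function name is illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> MatmulOutShape(std::vector<int64_t> x,
                                    std::vector<int64_t> y,
                                    bool trans_x,
                                    bool trans_y) {
  const bool x_bc = x.size() == 1, y_bc = y.size() == 1;
  if (x_bc) x.insert(x.begin(), 1);  // [K] -> [1, K]
  if (y_bc) y.push_back(1);          // [K] -> [K, 1]
  const int64_t M = trans_x ? x[x.size() - 1] : x[x.size() - 2];
  const int64_t N = trans_y ? y[y.size() - 2] : y[y.size() - 1];
  std::vector<int64_t> out;
  if (x.size() > y.size()) {
    out.assign(x.begin(), x.end() - 2);
  } else if (x.size() < y.size()) {
    out.assign(y.begin(), y.end() - 2);
  } else {
    for (size_t i = 0; i + 2 < x.size(); ++i) {
      out.push_back(std::max(x[i], y[i]));  // batch dims broadcast pairwise
    }
  }
  if (!x_bc) out.push_back(M);
  if (!y_bc) out.push_back(N);
  if (x_bc && y_bc) out.push_back(1);  // vector-vector product -> [1]
  return out;
}

For example, MatmulOutShape({2, 3, 4}, {2, 4, 5}, false, false) yields {2, 3, 5}.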
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -24,168 +24,131 @@
 namespace paddle {
 namespace operators {
-static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
-                                      const std::string input_name) {
-  auto shape = ctx.Attrs().Get<std::vector<int>>("fused_reshape_" + input_name);
-  auto axis =
-      ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
-  auto dim = ctx.GetInputDim(input_name);
-  PADDLE_ENFORCE_GT(dim.size(),
-                    0,
-                    platform::errors::InvalidArgument(
-                        "The Input(%s) has not been initialized properly. The "
-                        "shape of Input(%s) = [%s].",
-                        dim));
-
-  if (!shape.empty() && !axis.empty()) {
-    dim = dim.reshape(shape).transpose(axis);
-  }
-  return dim;
-}
-
-class MatMulV2Op : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matmul_v2");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "matmul_v2");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matmul_v2");
-    bool trans_x = ctx->Attrs().Get<bool>("trans_x");
-    bool trans_y = ctx->Attrs().Get<bool>("trans_y");
-    std::vector<int64_t> dims_x = phi::vectorize(GetDimForInput(*ctx, "X"));
-    std::vector<int64_t> dims_y = phi::vectorize(GetDimForInput(*ctx, "Y"));
-    auto ndims_x = dims_x.size();
-    auto ndims_y = dims_y.size();
-    PADDLE_ENFORCE_GT(ndims_x,
-                      0,
-                      platform::errors::InvalidArgument(
-                          "The Input(X) dims size must be greater than 0,"
-                          " but received dims size is 0. "));
-    PADDLE_ENFORCE_GT(ndims_y,
-                      0,
-                      platform::errors::InvalidArgument(
-                          "The Input(Y) dims size must be greater than 0,"
-                          " but received dims size is 0. "));
-
-    bool x_broadcasted = false, y_broadcasted = false;
-    if (ndims_x == 1) {
-      dims_x.insert(dims_x.begin(), 1);
-      ndims_x = 2;
-      x_broadcasted = true;
-    }
-    if (ndims_y == 1) {
-      dims_y.push_back(1);
-      ndims_y = 2;
-      y_broadcasted = true;
-    }
-
-    size_t M, N;
-    if (trans_x) {
-      M = dims_x[ndims_x - 1];
-    } else {
-      M = dims_x[ndims_x - 2];
-    }
-    if (trans_y) {
-      N = dims_y[ndims_y - 2];
-    } else {
-      N = dims_y[ndims_y - 1];
-    }
-
-    std::vector<int64_t> new_dims;
-    if (ndims_x > ndims_y) {
-      new_dims.assign(dims_x.begin(), dims_x.end() - 2);
-    } else if (ndims_x < ndims_y) {
-      new_dims.assign(dims_y.begin(), dims_y.end() - 2);
-    } else {
-      new_dims.reserve(ndims_x);
-      for (size_t i = 0; i < ndims_x - 2; ++i) {
-        new_dims.push_back(std::max(dims_x[i], dims_y[i]));
-      }
-    }
-    if (!x_broadcasted) {
-      new_dims.push_back(M);
-    }
-    if (!y_broadcasted) {
-      new_dims.push_back(N);
-    }
-    if (x_broadcasted && y_broadcasted) {
-      new_dims.push_back(1);
-    }
-
-    auto ddim_out = phi::make_ddim(new_dims);
-
-#ifdef PADDLE_WITH_MKLDNN
-    auto shape = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
-    auto axis = ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
-
-    if (!shape.empty() && !axis.empty()) {
-      ddim_out = ddim_out.transpose(axis).reshape(shape);
-    }
-#endif
-
-    ctx->SetOutputDim("Out", ddim_out);
-    ctx->ShareLoD("X", "Out");
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type =
-        OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
-    return phi::KernelKey(input_data_type, ctx.GetPlace());
-  }
-
-  phi::KernelKey GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const phi::KernelKey& expected_kernel_type) const override {
-    if (framework::IsComplexType(expected_kernel_type.dtype())) {
-      // only promote inputs’s types when contains complex input
-      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
-    } else {
-#ifdef PADDLE_WITH_MKLDNN
-      // When matmul_v2 is first oneDNN op in a chain (there was some non oneDNN
-      // op previously) then we also need to rotate shape NHWC -> NCWH
-      if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
-          (tensor.layout() != phi::DataLayout::ONEDNN) &&
-          phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
-              phi::DataLayout::kNHWC) {
-        return phi::KernelKey(tensor.place(),
-                              phi::DataLayout::kNHWC,
-                              expected_kernel_type.dtype());
-      }
-#endif
-      return phi::KernelKey(
-          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
-    }
-  }
-};
-
-class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "tensor of shape (d0, d1 ... M, K)");
-    AddInput("Y", "tensor of shape (d0, d1 ... K, N)");
-    AddOutput("Out", "tensor of shape (d0, d1 ... M, N)");
-    AddAttr<bool>("trans_x",
-                  "Set true to transpose the last two dimensions of X before "
-                  "doing multiplication")
-        .SetDefault(false);
-    AddAttr<bool>("trans_y",
-                  "Set true to transpose the last two dimensions of Y before "
-                  "doing multiplication")
-        .SetDefault(false);
-    AddComment(
-        R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K),
-B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)).
-In addition, it also follows the broadcast rule which is similar as
-numpy.matmul.
-)DOC");
-  }
-};
+void MatMulV2Op::InferShape(framework::InferShapeContext* ctx) const {
+  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matmul_v2");
+  OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "matmul_v2");
+  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matmul_v2");
+  bool trans_x = ctx->Attrs().Get<bool>("trans_x");
+  bool trans_y = ctx->Attrs().Get<bool>("trans_y");
+  std::vector<int64_t> dims_x = phi::vectorize(ctx->GetInputDim("X"));
+  std::vector<int64_t> dims_y = phi::vectorize(ctx->GetInputDim("Y"));
+  auto ndims_x = dims_x.size();
+  auto ndims_y = dims_y.size();
+  PADDLE_ENFORCE_GT(ndims_x,
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The Input(X) dims size must be greater than 0,"
+                        " but received dims size is 0. "));
+  PADDLE_ENFORCE_GT(ndims_y,
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The Input(Y) dims size must be greater than 0,"
+                        " but received dims size is 0. "));
+
+  bool x_broadcasted = false;
+  bool y_broadcasted = false;
+
+  if (ndims_x == 1) {
+    dims_x.insert(dims_x.begin(), 1);
+    ndims_x = 2;
+    x_broadcasted = true;
+  }
+
+  if (ndims_y == 1) {
+    dims_y.push_back(1);
+    ndims_y = 2;
+    y_broadcasted = true;
+  }
+
+  size_t M, N;
+  if (trans_x) {
+    M = dims_x[ndims_x - 1];
+  } else {
+    M = dims_x[ndims_x - 2];
+  }
+  if (trans_y) {
+    N = dims_y[ndims_y - 2];
+  } else {
+    N = dims_y[ndims_y - 1];
+  }
+
+  std::vector<int64_t> new_dims;
+  if (ndims_x > ndims_y) {
+    new_dims.assign(dims_x.begin(), dims_x.end() - 2);
+  } else if (ndims_x < ndims_y) {
+    new_dims.assign(dims_y.begin(), dims_y.end() - 2);
+  } else {
+    new_dims.reserve(ndims_x);
+    for (size_t i = 0; i < ndims_x - 2; ++i) {
+      new_dims.push_back(std::max(dims_x[i], dims_y[i]));
+    }
+  }
+  if (!x_broadcasted) {
+    new_dims.push_back(M);
+  }
+  if (!y_broadcasted) {
+    new_dims.push_back(N);
+  }
+  if (x_broadcasted && y_broadcasted) {
+    new_dims.push_back(1);
+  }
+
+  ctx->SetOutputDim("Out", phi::make_ddim(new_dims));
+  ctx->ShareLoD("X", "Out");
+}
+
+phi::KernelKey MatMulV2Op::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  auto input_data_type =
+      OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
+  return phi::KernelKey(input_data_type, ctx.GetPlace());
+}
+
+phi::KernelKey MatMulV2Op::GetKernelTypeForVar(
+    const std::string& var_name,
+    const phi::DenseTensor& tensor,
+    const phi::KernelKey& expected_kernel_type) const {
+  if (framework::IsComplexType(expected_kernel_type.dtype())) {
+    // only promote inputs’s types when contains complex input
+    return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
+  } else {
+#ifdef PADDLE_WITH_MKLDNN
+    // When matmul_v2 is first oneDNN op in a chain (there was some non oneDNN
+    // op previously) then we also need to rotate shape NHWC -> NCWH
+    if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
+        (tensor.layout() != phi::DataLayout::ONEDNN) &&
+        phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
+            phi::DataLayout::kNHWC) {
+      return phi::KernelKey(
+          tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype());
+    }
+#endif
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
+  }
+}
+
+void MatMulV2OpMaker::Make() {
+  AddInput("X", "tensor of shape (d0, d1 ... M, K)");
+  AddInput("Y", "tensor of shape (d0, d1 ... K, N)");
+  AddOutput("Out", "tensor of shape (d0, d1 ... M, N)");
+  AddAttr<bool>("trans_x",
+                "Set true to transpose the last two dimensions of X before "
+                "doing multiplication")
+      .SetDefault(false);
+  AddAttr<bool>("trans_y",
+                "Set true to transpose the last two dimensions of Y before "
+                "doing multiplication")
+      .SetDefault(false);
+  AddComment(
+      R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K),
+B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)).
+In addition, it also follows the broadcast rule which is similar as
+numpy.matmul.
+)DOC");
+  Apply();
+}
 
 class MatMulV2OpGrad : public framework::OperatorWithKernel {
  public:
......
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -37,6 +37,29 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+class MatMulV2Op : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  phi::KernelKey GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+
+  phi::KernelKey GetKernelTypeForVar(
+      const std::string& var_name,
+      const phi::DenseTensor& tensor,
+      const phi::KernelKey& expected_kernel_type) const override;
+};
+
+class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() final;
+
+ protected:
+  virtual void Apply() {}
+};
+
 // Reshape a rank-3 tensor from P x M x N to (P * M) x N.
 // Identity op if the tensor is not of rank 3.
 static phi::DenseTensor FoldInitDims(const phi::DenseTensor& input) {
......
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -102,12 +102,6 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
       {"fused_output_scale", ExtraAttrProperty::ONEDNN},
       {"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
       {"fuse_with_relu", ExtraAttrProperty::ONEDNN},
-      {"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
-      {"fused_transpose_Out", ExtraAttrProperty::ONEDNN},
-      {"fused_reshape_X", ExtraAttrProperty::ONEDNN},
-      {"fused_reshape_Y", ExtraAttrProperty::ONEDNN},
-      {"fused_transpose_X", ExtraAttrProperty::ONEDNN},
-      {"fused_transpose_Y", ExtraAttrProperty::ONEDNN},
       {"mkldnn_data_type", ExtraAttrProperty::ONEDNN},
       {"scale_x", ExtraAttrProperty::ONEDNN},
       {"scale_y", ExtraAttrProperty::ONEDNN},
@@ -226,8 +220,7 @@ class ExtraInfoUtils {
   std::unordered_map<std::string, std::vector<std::string>>
       g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}},
                                   {"conv2d_transpose", {"Bias"}},
-                                  {"conv2d_grad", {"Bias"}},
-                                  {"matmul_v2", {"ResidualData"}}};
+                                  {"conv2d_grad", {"Bias"}}};
   std::vector<std::string> empty_extra_input_names_;
 };
......
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -53,26 +53,31 @@ constexpr bool is_bfloat16() {
 static void AppendActivation(const OneDNNContext& dev_ctx,
                              dnnl::post_ops& post_ops,  // NOLINT
-                             float activation_scale = 1.0f) {
-  const auto invalid_attribute =
-      dev_ctx.HasDnnAttr("fuse_activation")
-          ? PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"))
-                .empty()
-          : true;
-  if (invalid_attribute) return;
-
-  const auto fuse_activation =
-      dev_ctx.HasDnnAttr("fuse_activation")
-          ? PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"))
-          : "";
-  const auto fuse_alpha =
-      dev_ctx.HasDnnAttr("fuse_alpha")
-          ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_alpha"))
-          : 0.0f;
-  const auto fuse_beta =
-      dev_ctx.HasDnnAttr("fuse_beta")
-          ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_beta"))
-          : 0.0f;
+                             float activation_scale = 1.0f,
+                             std::string fuse_activation = "",
+                             float fuse_alpha = 0.0f,
+                             float fuse_beta = 0.0f) {
+  if (fuse_activation == "") {
+    const auto invalid_attribute =
+        dev_ctx.HasDnnAttr("fuse_activation")
+            ? PADDLE_GET_CONST(std::string,
+                               dev_ctx.GetDnnAttr("fuse_activation"))
+                  .empty()
+            : true;
+    if (invalid_attribute) return;
+
+    fuse_activation =
+        dev_ctx.HasDnnAttr("fuse_activation")
+            ? PADDLE_GET_CONST(std::string,
+                               dev_ctx.GetDnnAttr("fuse_activation"))
+            : "";
+    fuse_alpha = dev_ctx.HasDnnAttr("fuse_alpha")
+                     ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_alpha"))
+                     : 0.0f;
+    fuse_beta = dev_ctx.HasDnnAttr("fuse_beta")
+                    ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_beta"))
+                    : 0.0f;
+  }
 
   if (fuse_activation == "hard_sigmoid") {
     post_ops.append_eltwise(activation_scale,
......
This diff is collapsed.
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -51,8 +51,7 @@ void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
                          const std::vector<int64_t> &y_dims,
                          std::vector<int64_t> *x_bd_dims,
                          std::vector<int64_t> *y_bd_dims,
-                         DenseTensor *out,
-                         const bool is_output_fused) {
+                         DenseTensor *out) {
   if (x_dims.size() == 1) {
     (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
   } else if (x_dims.size() == 2) {
@@ -74,7 +73,7 @@ void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
     }
   }
-  if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) {
+  if (x_dims.size() > 2 && y_dims.size() > 2) {
     auto out_dims = vectorize(out->dims());
     for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
       PADDLE_ENFORCE_EQ(
@@ -121,15 +120,6 @@ void MatmulKernel(const Context &dev_ctx,
           ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
           : false;
-  bool fuse_relu = false;
-  if (dev_ctx.HasDnnAttr("fuse_activation")) {
-    auto act_type =
-        PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"));
-    if (act_type == "relu" || act_type == "relu6") {
-      fuse_relu = true;
-    }
-  }
   auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X"));
   auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y"));
@@ -139,12 +129,7 @@ void MatmulKernel(const Context &dev_ctx,
   std::vector<int64_t> x_bd_dims(ndims, 1);
   std::vector<int64_t> y_bd_dims(ndims, 1);
-  CalculateMatrixDims(x_dims,
-                      y_dims,
-                      &x_bd_dims,
-                      &y_bd_dims,
-                      out,
-                      funcs::IsOutputFused(dev_ctx));
+  CalculateMatrixDims(x_dims, y_dims, &x_bd_dims, &y_bd_dims, out);
   if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
     funcs::ExecuteMatmul<T, float>(
@@ -152,9 +137,6 @@ void MatmulKernel(const Context &dev_ctx,
   } else if (is_bfloat16) {
     funcs::ExecuteMatmul<T, paddle::platform::bfloat16>(
         dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
-  } else if (fuse_relu) {
-    funcs::ExecuteMatmul<T, uint8_t>(
-        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
   } else {
     funcs::ExecuteMatmul<T, int8_t>(
         dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature FusedMatmulOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fused_matmul",
                         {"X", "Y", "ResidualData"},
                         {"trans_x",
                          "trans_y",
                          "matmul_alpha",
                          "fuse_activation",
                          "fuse_alpha",
                          "fuse_beta",
                          "fused_output_scale",
                          "fused_reshape_X",
                          "fused_transpose_X",
                          "fused_reshape_Y",
                          "fused_transpose_Y",
                          "fused_reshape_Out",
                          "fused_transpose_Out",
                          "mkldnn_data_type",
                          "Scale_x",
                          "Scale_y",
                          "Scale_in_eltwise",
                          "Scale_out",
                          "force_fp32_output"},
                         {"Out"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(fused_matmul, phi::FusedMatmulOpArgumentMapping);
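In the phi convention, the inputs, attributes, and outputs listed in this signature line up, in order, with the kernel's parameters. A hedged sketch of the kernel declaration implied by the mapping (parameter spellings are assumptions; the actual kernel lives in paddle/phi/kernels/fusion/onednn/fused_matmul_kernel.cc):

// Assumed declaration, inferred from the KernelSignature above.
template <typename T, typename Context>
void FusedMatmulKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& y,
                       const paddle::optional<DenseTensor>& residual_data,
                       bool trans_x,
                       bool trans_y,
                       float matmul_alpha,
                       const std::string& fuse_activation,
                       float fuse_alpha,
                       float fuse_beta,
                       float fused_output_scale,
                       const std::vector<int>& fused_reshape_x,
                       const std::vector<int>& fused_transpose_x,
                       const std::vector<int>& fused_reshape_y,
                       const std::vector<int>& fused_transpose_y,
                       const std::vector<int>& fused_reshape_out,
                       const std::vector<int>& fused_transpose_out,
                       const std::string& mkldnn_data_type,
                       float scale_x,
                       float scale_y,
                       float scale_in_eltwise,
                       float scale_out,
                       bool force_fp32_output,
                       DenseTensor* out);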
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -103,12 +103,6 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
                 alpha=alpha,
                 trans_x=transpose_X,
                 trans_y=transpose_Y,
-                fused_reshape_Out=[],
-                fused_transpose_Out=[],
-                fused_reshape_X=[],
-                fused_reshape_Y=[],
-                fused_transpose_X=[],
-                fused_transpose_Y=[],
             )
         ops = [
......
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -92,12 +92,6 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
                 alpha=alpha,
                 trans_x=transpose_X,
                 trans_y=transpose_Y,
-                fused_reshape_Out=[],
-                fused_transpose_Out=[],
-                fused_reshape_X=[],
-                fused_reshape_Y=[],
-                fused_transpose_X=[],
-                fused_transpose_Y=[],
             )
         ops = [
......
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -76,12 +76,6 @@ class TestMatmulV2ScaleFusePass(PassAutoScanTest):
             outputs={"Out": ["matmul_out"]},
             trans_x=transpose_X,
             trans_y=transpose_Y,
-            fused_reshape_X=[],
-            fused_reshape_Y=[],
-            fused_transpose_X=[],
-            fused_transpose_Y=[],
-            fused_reshape_Out=[],
-            fused_transpose_Out=[],
         )
         is_scale_tensor = draw(st.booleans())
         if is_scale_tensor:
......
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the 'License');
 # you may not use this file except in compliance with the License.
@@ -146,7 +146,7 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
                 'operator_scale_onednn_fuse_pass',
             ],
         )
-        yield config, ['matmul'], (1e-5, 1e-5)
+        yield config, ['fused_matmul'], (1e-5, 1e-5)

     def test(self):
         self.run_and_statis(
......
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the 'License');
 # you may not use this file except in compliance with the License.
@@ -137,7 +137,7 @@ class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest):
                 'matmul_activation_mkldnn_fuse_pass',
             ],
         )
-        yield config, ['matmul'], (1e-5, 1e-5)
+        yield config, ['fused_matmul'], (1e-5, 1e-5)

     def test(self):
         self.run_and_statis(
......
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the 'License');
 # you may not use this file except in compliance with the License.
@@ -76,7 +76,7 @@ class TestMatmulElementwiseAddMkldnnFusePass(PassAutoScanTest):
         config = self.create_inference_config(
             use_mkldnn=True, passes=['matmul_elementwise_add_mkldnn_fuse_pass']
         )
-        yield config, ['matmul'], (1e-5, 1e-5)
+        yield config, ['fused_matmul'], (1e-5, 1e-5)

     def test(self):
         self.run_and_statis(
......
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -150,7 +150,7 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
                 'operator_scale_onednn_fuse_pass',
             ],
         )
-        yield config, ['matmul_v2'], (1e-5, 1e-5)
+        yield config, ['fused_matmul'], (1e-5, 1e-5)

     def test(self):
         self.run_and_statis(
......
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the 'License');
 # you may not use this file except in compliance with the License.
@@ -87,7 +87,7 @@ class TestMatmulV2ElementwiseAddMkldnnFusePass(PassAutoScanTest):
     def sample_predictor_configs(self, program_config):
         config = self.create_inference_config(use_mkldnn=True)
-        yield config, ['matmul_v2'], (1e-5, 1e-5)
+        yield config, ['fused_matmul'], (1e-5, 1e-5)

     def test(self):
         self.run_and_statis(
......
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -90,12 +90,6 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
             attrs={
                 "trans_x": transpose_X,
                 "trans_y": transpose_Y,
-                "fused_reshape_X": [],
-                "fused_reshape_Y": [],
-                "fused_transpose_X": [],
-                "fused_transpose_Y": [],
-                "fused_reshape_Out": [],
-                "fused_transpose_Out": [],
             },
         )
@@ -135,17 +129,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
         return program_config

     def sample_predictor_configs(self, program_config):
-        # gpu_cpu_map_matmul_v2_to_matmul_pass will affect the type of final fused op
-        fused_op = "matmul_v2"
-        input1_dim1 = program_config.inputs["input_data1"].shape[0]
-        input2_dim1 = program_config.inputs["input_data2"].shape[0]
-        input1_dim2 = program_config.inputs["input_data1"].shape[1]
-        input2_dim2 = program_config.inputs["input_data2"].shape[1]
-        if input1_dim1 == input2_dim1 and input1_dim2 == input2_dim2:
-            fused_op = "matmul"
         config = self.create_inference_config(use_mkldnn=True)
-        yield config, [fused_op], (1e-5, 1e-5)
+        yield config, ['fused_matmul'], (1e-5, 1e-5)

     def test(self):
         self.run_and_statis(
......
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -93,12 +93,6 @@ class TestMkldnnMatmulv2Op(MkldnnAutoScanTest):
             attrs={
                 "trans_x": kwargs["transpose_X"],
                 "trans_y": kwargs["transpose_Y"],
-                "fused_reshape_X": [],
-                "fused_reshape_Y": [],
-                "fused_transpose_X": [],
-                "fused_transpose_Y": [],
-                "fused_reshape_Out": [],
-                "fused_transpose_Out": [],
             },
         )
......
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -116,7 +116,7 @@ class TestOneDNNMatmulTransposeReshapeFusePass(PassAutoScanTest):
     def sample_predictor_configs(self, program_config):
         config = self.create_inference_config(use_mkldnn=True)
-        yield config, ['matmul'], (1e-5, 1e-5)
+        yield config, ['fused_matmul'], (1e-5, 1e-5)

     def test(self):
         self.run_and_statis(
......
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -154,7 +154,7 @@ class TestOneDNNReshapeTransposeMatmulFusePass(PassAutoScanTest):
     def sample_predictor_configs(self, program_config):
         config = self.create_inference_config(use_mkldnn=True)
-        yield config, ['matmul'], (1e-5, 1e-5)
+        yield config, ['fused_matmul'], (1e-5, 1e-5)

     def test(self):
         self.run_and_statis(
......