Unverified commit 5cfe1645, authored by Sławomir Siwek, committed by GitHub

Replace matmul(v2) with fused_matmul during oneDNN fuse passes (#49515)

* replace matmul with matmul_v2 in fuse passes

* Remove fusion logic from matmul

* removing fusion methods

* add proper name

* adjust namespaces

* clean attrs in python tests

* delete checkpoint and restore matmul version

* remove unused code

* matmul and reshape/transpose fuses migrated

* split MatmulOneDNN headers

* fuse activation and eltwise_add

* add fuse_activation

* matmul_transpose_reshape/reshape_transpose_matmul

* matmul + elementwise_add (fused)

* activation temporary modification

* merge newest develop

* remove dependency from other PR

* revert pbtxt

* remove placeholders from matmul_v2

* add description in OPMaker

* remove matmul_v2_op.h and all dependencies

* remove dims changing in base op

* add possibility to fuse already fused_matmul

* restart broken CI

* Empty-Commit

* revert matmul_utils.h

* codestyle

* adjust imports

* add pbtxt file

* 100% matmul unit test coverage

* trigger CI with minimal changes to develop

* adjust changes to develop

* add fused_matmul op

* inherit base ops

* add "v2"

* move OPMaker

* Gradually add fused_matmul files

* second batch of fused_matmul changes

* split infershapes of matmul_v2 and fused_matmul

* inherit fused_matmul from matmul_v2

* Update paddle/phi/backends/onednn/onednn_reuse.h
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>

* Update paddle/phi/kernels/fusion/onednn/fused_matmul_kernel.cc
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>

---------
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>
Parent: aa8cef4a
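All four oneDNN fuse passes changed below repeat the same rewrite when they match a node: retype it to fused_matmul and, if the source op is the legacy matmul(v1), translate its attribute names to the fused_matmul spelling. A minimal sketch of that shared step follows; the helper name ConvertToFusedMatmul is illustrative only and does not exist in this patch:

// Sketch of the per-pass rewrite seen in the diffs below; hypothetical
// helper, not part of this commit.
void ConvertToFusedMatmul(OpDesc* op, const std::string& matmul_type) {
  op->SetType("fused_matmul");
  if (matmul_type == "matmul") {
    // matmul(v1) spells its attributes transpose_X/transpose_Y/alpha;
    // fused_matmul expects trans_x/trans_y/matmul_alpha.
    op->SetAttr("trans_x", op->GetAttr("transpose_X"));
    op->SetAttr("trans_y", op->GetAttr("transpose_Y"));
    op->SetAttr("matmul_alpha", op->GetAttr("alpha"));
  }
  // matmul_v2 already uses trans_x/trans_y and has no alpha attribute,
  // so only the type change is needed.
}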
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ using string::PrettyLogDetail;
 void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
   auto act_types = GetSupportedActivations();
-  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_types = {"fused_matmul", "matmul", "matmul_v2"};
   for (const auto& matmul_type : matmul_types)
     for (auto& act_type : act_types) {
@@ -61,8 +61,17 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
     GET_IR_NODE_FROM_SUBGRAPH(
         activation_out, activation_out, matmul_act_pattern);
-    SetActivationAttrs(matmul->Op(), activation->Op(), act_type);
-    matmul->Op()->SetOutput("Out", {activation_out->Name()});
+    OpDesc* matmul_op = matmul->Op();
+    matmul_op->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      matmul_op->SetAttr("trans_x", matmul_op->GetAttr("transpose_X"));
+      matmul_op->SetAttr("trans_y", matmul_op->GetAttr("transpose_Y"));
+      matmul_op->SetAttr("matmul_alpha", matmul_op->GetAttr("alpha"));
+    }
+    SetActivationAttrs(matmul_op, activation->Op(), act_type);
+    matmul_op->SetOutput("Out", {activation_out->Name()});
     IR_NODE_LINK_TO(matmul, activation_out);
     GraphSafeRemoveNodes(graph, {activation, matmul_out});
@@ -88,11 +97,6 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
       .AddInput("Y")
       .IsTensor()
       .End()
-      .AddInput(
-          "ResidualData")  // Extra tensor used in matmul+elementwise_add fuse
-      .IsTensor()
-      .IsOptional()
-      .End()
       .AddOutput("Out")
      .IsTensor()
       .End()
@@ -113,8 +117,24 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
       .AddInput("Y")
       .IsTensor()
       .End()
-      .AddInput(
-          "ResidualData")  // Extra tensor used in matmul+elementwise_add fuse
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End();
+  AddOpCompat(OpCompat("fused_matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
       .IsTensor()
       .IsOptional()
       .End()
@@ -126,6 +146,50 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
       .End()
       .AddAttr("trans_y")
       .IsType<bool>()
+      .End()
+      .AddAttr("matmul_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_activation")
+      .IsType<std::string>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_beta")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_output_scale")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
       .End();
   AddOpCompat(OpCompat("abs"))
@@ -279,6 +343,7 @@ REGISTER_PASS(matmul_activation_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(matmul_activation_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("fused_matmul", 0)
             .LE("matmul", 1)
             .EQ("matmul_v2", 0)
             .EQ("abs", 0)
...
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ namespace ir {
 using string::PrettyLogDetail;
 void MatmulElementwiseAddMKLDNNFusePass::ApplyImpl(Graph* graph) const {
-  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_types = {"fused_matmul", "matmul", "matmul_v2"};
   auto matmul_as_x = {true, false};
   for (const auto& matmul_type : matmul_types)
@@ -65,6 +65,12 @@ void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
       return;
     }
+    matmul->Op()->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      matmul->Op()->SetAttr("trans_x", matmul->Op()->GetAttr("transpose_X"));
+      matmul->Op()->SetAttr("trans_y", matmul->Op()->GetAttr("transpose_Y"));
+      matmul->Op()->SetAttr("matmul_alpha", matmul->Op()->GetAttr("alpha"));
+    }
     matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()});
     matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()});
@@ -125,6 +131,71 @@ MatmulElementwiseAddMKLDNNFusePass::MatmulElementwiseAddMKLDNNFusePass() {
       .IsType<bool>()
       .End();
+  AddOpCompat(OpCompat("fused_matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End()
+      .AddAttr("matmul_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_activation")
+      .IsType<std::string>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_beta")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_output_scale")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End();
   AddOpCompat(OpCompat("elementwise_add"))
       .AddInput("X")
       .IsTensor()
@@ -149,6 +220,7 @@ REGISTER_PASS(matmul_elementwise_add_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(matmul_elementwise_add_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("fused_matmul", 0)
             .LE("matmul", 1)
             .EQ("matmul_v2", 0)
             .LE("elementwise_add", 1));
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ namespace ir {
 using string::PrettyLogDetail;
 void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(Graph *graph) const {
-  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_types = {"fused_matmul", "matmul", "matmul_v2"};
   for (const auto &matmul_type : matmul_types) {
     Fuse(graph, matmul_type);
@@ -84,6 +84,12 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse(
     }
     OpDesc *matmul_desc = matmul_op->Op();
+    matmul_desc->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
+      matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
+      matmul_desc->SetAttr("matmul_alpha", matmul_desc->GetAttr("alpha"));
+    }
     matmul_desc->SetOutput("Out", {reshape_out->Name()});
     matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
     matmul_desc->SetAttr("fused_transpose_Out", transpose_axis);
@@ -149,6 +155,71 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
       .IsType<bool>()
       .End();
+  AddOpCompat(OpCompat("fused_matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End()
+      .AddAttr("matmul_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_activation")
+      .IsType<std::string>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_beta")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_output_scale")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End();
   AddOpCompat(OpCompat("transpose2"))
       .AddInput("X")
       .IsTensor()
@@ -189,6 +260,7 @@ REGISTER_PASS(matmul_transpose_reshape_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("fused_matmul", 0)
             .LE("matmul", 1)
             .EQ("matmul_v2", 0)
             .EQ("transpose2", 0)
...
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ using string::PrettyLogDetail;
 void FuseOperatorScaleOneDNNPass::ApplyImpl(Graph *graph) const {
   const std::vector<std::string> fusable_ops{
       "fc",
+      "fused_matmul",
       "matmul",
       "matmul_v2",
       "elementwise_add",
@@ -85,6 +86,19 @@ void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
     scale = *(scale_tensor->data<float>());
   }
+  if (op_type == "matmul") {
+    operator_op->Op()->SetType("fused_matmul");
+    operator_op->Op()->SetAttr("trans_x",
+                               operator_op->Op()->GetAttr("transpose_X"));
+    operator_op->Op()->SetAttr("trans_y",
+                               operator_op->Op()->GetAttr("transpose_Y"));
+    operator_op->Op()->SetAttr("matmul_alpha",
+                               operator_op->Op()->GetAttr("alpha"));
+  }
+  if (op_type == "matmul_v2") {
+    operator_op->Op()->SetType("fused_matmul");
+  }
   operator_op->Op()->SetAttr("fused_output_scale", scale);
   operator_op->Op()->SetOutput("Out", {scale_out->Name()});
@@ -111,6 +125,7 @@ REGISTER_PASS_CAPABILITY(operator_scale_onednn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .EQ("fc", 0)
+            .EQ("fused_matmul", 0)
             .LE("matmul", 1)
             .EQ("matmul_v2", 0)
             .LE("elementwise_add", 1)
...
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ namespace framework {
 namespace ir {
 void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(Graph *graph) const {
-  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_types = {"matmul", "matmul_v2", "fused_matmul"};
   for (const auto &matmul_type : matmul_types) {
     Fuse(graph,
@@ -102,6 +102,25 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
                                      matmul_type + " encountered.");
     }
+    // Return if input of fused_matmul is already fused
+    if (matmul_type == "fused_matmul") {
+      auto is_already_fused_X =
+          matmul_desc->HasAttr("fused_reshape_X")
+              ? !(PADDLE_GET_CONST(std::vector<int>,
+                                   matmul_desc->GetAttr("fused_reshape_X"))
+                      .empty())
+              : false;
+      if (is_already_fused_X && matmul_input_name == "X") return;
+      auto is_already_fused_Y =
+          matmul_desc->HasAttr("fused_reshape_Y")
+              ? !(PADDLE_GET_CONST(std::vector<int>,
+                                   matmul_desc->GetAttr("fused_reshape_Y"))
+                      .empty())
+              : false;
+      if (is_already_fused_Y && matmul_input_name == "Y") return;
+    }
     auto reshape_shape =
         paddle::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
     auto transpose_axis =
@@ -123,6 +142,12 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
       return;
     }
+    matmul_desc->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
+      matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
+      matmul_desc->SetAttr("matmul_alpha", matmul_desc->GetAttr("alpha"));
+    }
     matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
     matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
     matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
@@ -220,6 +245,71 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
       .AddAttr("trans_y")
       .IsType<bool>()
       .End();
+  AddOpCompat(OpCompat("fused_matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End()
+      .AddAttr("matmul_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_activation")
+      .IsType<std::string>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_alpha")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fuse_beta")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_output_scale")
+      .IsType<float>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_X")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Y")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_reshape_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("fused_transpose_Out")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End();
 }
 }  // namespace ir
@@ -234,5 +324,6 @@ REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_mkldnn_fuse_pass)
         paddle::framework::compatible::OpVersionComparatorCombination()
             .EQ("reshape2", 0)
             .EQ("transpose2", 0)
+            .EQ("fused_matmul", 0)
             .EQ("matmul", 1)
             .EQ("matmul_v2", 0));
type: "fused_matmul"
def {
inputs {
name: "X"
}
inputs {
name: "Y"
}
inputs {
name: "ResidualData"
}
outputs {
name: "Out"
}
attrs {
name: "trans_x"
type: BOOLEAN
}
attrs {
name: "trans_y"
type: BOOLEAN
}
}
extra {
attrs {
name: "matmul_alpha"
type: FLOAT
}
attrs {
name: "fuse_activation"
type: STRING
}
attrs {
name: "fuse_alpha"
type: FLOAT
}
attrs {
name: "fuse_beta"
type: FLOAT
}
attrs {
name: "fused_output_scale"
type: FLOAT
}
attrs {
name: "fused_reshape_X"
type: INTS
}
attrs {
name: "fused_transpose_X"
type: INTS
}
attrs {
name: "fused_reshape_Y"
type: INTS
}
attrs {
name: "fused_transpose_Y"
type: INTS
}
attrs {
name: "fused_reshape_Out"
type: INTS
}
attrs {
name: "fused_transpose_Out"
type: INTS
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "Scale_x"
type: FLOAT
}
attrs {
name: "Scale_y"
type: FLOAT
}
attrs {
name: "Scale_in_eltwise"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
}
}
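Note on the file above: the def block lists only the signature-defining inputs, outputs, and the trans_x/trans_y attributes that every consumer of fused_matmul must understand, while the extra block carries the oneDNN-only fusion attributes that the fuse passes above write onto the op.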
@@ -39,28 +39,4 @@ extra {
     name: "op_device"
     type: STRING
   }
-  attrs {
-    name: "fused_reshape_X"
-    type: INTS
-  }
-  attrs {
-    name: "fused_reshape_Y"
-    type: INTS
-  }
-  attrs {
-    name: "fused_transpose_X"
-    type: INTS
-  }
-  attrs {
-    name: "fused_transpose_Y"
-    type: INTS
-  }
-  attrs {
-    name: "fused_reshape_Out"
-    type: INTS
-  }
-  attrs {
-    name: "fused_transpose_Out"
-    type: INTS
-  }
 }
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/matmul_v2_op.h"
namespace paddle {
namespace operators {
static std::vector<int64_t> GetInputShape(phi::DDim dim,
std::vector<int> shape,
std::vector<int> axis) {
PADDLE_ENFORCE_GT(dim.size(),
0,
phi::errors::InvalidArgument(
"The Input(%s) has not been initialized properly. The "
"shape of Input(%s) = [%s].",
dim));
auto is_input_fused = (!shape.empty() && !axis.empty());
if (is_input_fused) {
dim = dim.reshape(shape).transpose(axis);
}
return phi::vectorize(dim);
}
class FusedMatmulOp : public MatMulV2Op {
public:
using MatMulV2Op::MatMulV2Op;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fused_matmul");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fused_matmul");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fused_matmul");
bool trans_x = ctx->Attrs().Get<bool>("trans_x");
bool trans_y = ctx->Attrs().Get<bool>("trans_y");
std::vector<int64_t> dims_x =
GetInputShape(ctx->GetInputDim("X"),
ctx->Attrs().Get<std::vector<int>>("fused_reshape_X"),
ctx->Attrs().Get<std::vector<int>>("fused_transpose_X"));
std::vector<int64_t> dims_y =
GetInputShape(ctx->GetInputDim("Y"),
ctx->Attrs().Get<std::vector<int>>("fused_reshape_Y"),
ctx->Attrs().Get<std::vector<int>>("fused_transpose_Y"));
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x,
0,
phi::errors::InvalidArgument(
"The Input(X) dims size must be greater than 0,"
" but received dims size is 0. "));
PADDLE_ENFORCE_GT(ndims_y,
0,
phi::errors::InvalidArgument(
"The Input(Y) dims size must be greater than 0,"
" but received dims size is 0. "));
bool x_broadcasted = false;
bool y_broadcasted = false;
if (ndims_x == 1) {
dims_x.insert(dims_x.begin(), 1);
ndims_x = 2;
x_broadcasted = true;
}
if (ndims_y == 1) {
dims_y.push_back(1);
ndims_y = 2;
y_broadcasted = true;
}
size_t M, N;
if (trans_x) {
M = dims_x[ndims_x - 1];
} else {
M = dims_x[ndims_x - 2];
}
if (trans_y) {
N = dims_y[ndims_y - 2];
} else {
N = dims_y[ndims_y - 1];
}
std::vector<int64_t> new_dims;
if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
auto ddim_out = phi::make_ddim(new_dims);
auto shape = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto axis = ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
auto is_output_fused = (!shape.empty() && !axis.empty());
if (is_output_fused) {
ddim_out = ddim_out.transpose(axis).reshape(shape);
}
ctx->SetOutputDim("Out", ddim_out);
ctx->ShareLoD("X", "Out");
}
};
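// Example of the shape inference above (illustrative values): with
// X = [2, 3, 4], Y = [4, 5], trans_x = trans_y = false, and no fused
// reshape/transpose attributes, the batch dims come from the higher-rank
// input ([2]), M = 3 and N = 5, so Out = [2, 3, 5]. With X = [4] and
// Y = [4], both sides are broadcast and Out collapses to [1].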
class FusedMatmulOpMaker : public MatMulV2OpMaker {
protected:
void Apply() override {
AddInput("ResidualData",
"Extra input from matmul_elementwise_add_mkldnn_fuse_pass")
.AsDispensable()
.AsExtra();
AddAttr<float>("matmul_alpha", "Output scale used in matmul_v1")
.SetDefault(1.0f);
AddAttr<std::string>(
"fuse_activation",
"Activation type from matmul_activation_mkldnn_fuse_pass")
.SetDefault("");
AddAttr<float>("fuse_alpha",
"Activation alpha from matmul_activation_mkldnn_fuse_pass")
.SetDefault(0.0f);
AddAttr<float>("fuse_beta",
"Activation beta from matmul_activation_mkldnn_fuse_pass")
.SetDefault(0.0f);
AddAttr<float>("fused_output_scale",
"Output scale from operator_scale_onednn_fuse_pass")
.SetDefault(1.0f);
AddAttr<std::vector<int>>("fused_reshape_X",
"Reshape's shape attribute from "
"reshape_transpose_matmul_mkldnn_fuse_pass")
.SetDefault({});
AddAttr<std::vector<int>>("fused_transpose_X",
"Transpose's axis attribute from "
"reshape_transpose_matmul_mkldnn_fuse_pass")
.SetDefault({});
AddAttr<std::vector<int>>("fused_reshape_Y",
"Reshape's shape attribute from "
"reshape_transpose_matmul_mkldnn_fuse_pass")
.SetDefault({});
AddAttr<std::vector<int>>("fused_transpose_Y",
"Transpose's axis attribute from "
"reshape_transpose_matmul_mkldnn_fuse_pass")
.SetDefault({});
AddAttr<std::vector<int>>("fused_reshape_Out",
"Reshape's shape attribute from "
"matmul_transpose_reshape_mkldnn_fuse_pass")
.SetDefault({});
AddAttr<std::vector<int>>("fused_transpose_Out",
"Transpose's axis attribute from "
"matmul_transpose_reshape_mkldnn_fuse_pass")
.SetDefault({});
AddAttr<std::string>("mkldnn_data_type", "oneDNN operator data type")
.SetDefault("float32")
.InEnum({"float32", "int8", "bfloat16"});
AddAttr<float>("Scale_x", "Matmul X input quantization scale")
.SetDefault(1.0f);
AddAttr<float>("Scale_y", "Matmul Y input quantization scale")
.SetDefault(1.0f);
AddAttr<float>("Scale_in_eltwise", "Matmul ResidualData quantization scale")
.SetDefault(0.0f);
AddAttr<float>("Scale_out", "Matmul output quantization scale")
.SetDefault(1.0f);
AddAttr<bool>("force_fp32_output",
"Flag determining if output should be converted to FP32")
.SetDefault(false);
AddComment(
R"DOC(Matrix multiplication extended with oneDNN-specific fusion logic.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_matmul,
ops::FusedMatmulOp,
ops::FusedMatmulOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -24,52 +24,31 @@
 namespace paddle {
 namespace operators {
-static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
-                                      const std::string input_name) {
-  auto shape = ctx.Attrs().Get<std::vector<int>>("fused_reshape_" + input_name);
-  auto axis =
-      ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
-  auto dim = ctx.GetInputDim(input_name);
-  PADDLE_ENFORCE_GT(dim.size(),
-                    0,
-                    platform::errors::InvalidArgument(
-                        "The Input(%s) has not been initialized properly. The "
-                        "shape of Input(%s) = [%s].",
-                        dim));
-  if (!shape.empty() && !axis.empty()) {
-    dim = dim.reshape(shape).transpose(axis);
-  }
-  return dim;
-}
-class MatMulV2Op : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
+void MatMulV2Op::InferShape(framework::InferShapeContext* ctx) const {
   OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matmul_v2");
   OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "matmul_v2");
   OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matmul_v2");
   bool trans_x = ctx->Attrs().Get<bool>("trans_x");
   bool trans_y = ctx->Attrs().Get<bool>("trans_y");
-  std::vector<int64_t> dims_x = phi::vectorize(GetDimForInput(*ctx, "X"));
-  std::vector<int64_t> dims_y = phi::vectorize(GetDimForInput(*ctx, "Y"));
+  std::vector<int64_t> dims_x = phi::vectorize(ctx->GetInputDim("X"));
+  std::vector<int64_t> dims_y = phi::vectorize(ctx->GetInputDim("Y"));
   auto ndims_x = dims_x.size();
   auto ndims_y = dims_y.size();
   PADDLE_ENFORCE_GT(ndims_x,
                     0,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "The Input(X) dims size must be greater than 0,"
                         " but received dims size is 0. "));
   PADDLE_ENFORCE_GT(ndims_y,
                     0,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "The Input(Y) dims size must be greater than 0,"
                         " but received dims size is 0. "));
-  bool x_broadcasted = false, y_broadcasted = false;
+  bool x_broadcasted = false;
+  bool y_broadcasted = false;
   if (ndims_x == 1) {
     dims_x.insert(dims_x.begin(), 1);
     ndims_x = 2;
@@ -115,33 +94,21 @@ class MatMulV2Op : public framework::OperatorWithKernel {
     new_dims.push_back(1);
   }
-  auto ddim_out = phi::make_ddim(new_dims);
-#ifdef PADDLE_WITH_MKLDNN
-  auto shape = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
-  auto axis = ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
-  if (!shape.empty() && !axis.empty()) {
-    ddim_out = ddim_out.transpose(axis).reshape(shape);
-  }
-#endif
-  ctx->SetOutputDim("Out", ddim_out);
+  ctx->SetOutputDim("Out", phi::make_ddim(new_dims));
   ctx->ShareLoD("X", "Out");
 }
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
+phi::KernelKey MatMulV2Op::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
   auto input_data_type =
       OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
   return phi::KernelKey(input_data_type, ctx.GetPlace());
 }
-  phi::KernelKey GetKernelTypeForVar(
+phi::KernelKey MatMulV2Op::GetKernelTypeForVar(
     const std::string& var_name,
     const phi::DenseTensor& tensor,
-    const phi::KernelKey& expected_kernel_type) const override {
+    const phi::KernelKey& expected_kernel_type) const {
   if (framework::IsComplexType(expected_kernel_type.dtype())) {
     // only promote inputs’s types when contains complex input
     return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
@@ -153,20 +120,16 @@ class MatMulV2Op : public framework::OperatorWithKernel {
       (tensor.layout() != phi::DataLayout::ONEDNN) &&
       phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
          phi::DataLayout::kNHWC) {
-    return phi::KernelKey(tensor.place(),
-                          phi::DataLayout::kNHWC,
-                          expected_kernel_type.dtype());
+    return phi::KernelKey(
+        tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype());
   }
#endif
   return phi::KernelKey(
       tensor.place(), tensor.layout(), expected_kernel_type.dtype());
 }
 }
-};
-class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
+void MatMulV2OpMaker::Make() {
   AddInput("X", "tensor of shape (d0, d1 ... M, K)");
   AddInput("Y", "tensor of shape (d0, d1 ... K, N)");
   AddOutput("Out", "tensor of shape (d0, d1 ... M, N)");
@@ -184,8 +147,8 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
   In addition, it also follows the broadcast rule which is similar as
   numpy.matmul.
 )DOC");
-  }
-};
+  Apply();
+}
 class MatMulV2OpGrad : public framework::OperatorWithKernel {
  public:
...
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -37,6 +37,29 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+class MatMulV2Op : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  phi::KernelKey GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+  phi::KernelKey GetKernelTypeForVar(
+      const std::string& var_name,
+      const phi::DenseTensor& tensor,
+      const phi::KernelKey& expected_kernel_type) const override;
+};
+
+class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() final;
+
+ protected:
+  virtual void Apply() {}
+};
 // Reshape a rank-3 tensor from P x M x N to (P * M) x N.
 // Identity op if the tensor is not of rank 3.
 static phi::DenseTensor FoldInitDims(const phi::DenseTensor& input) {
...
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -102,12 +102,6 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
         {"fused_output_scale", ExtraAttrProperty::ONEDNN},
         {"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
         {"fuse_with_relu", ExtraAttrProperty::ONEDNN},
-        {"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
-        {"fused_transpose_Out", ExtraAttrProperty::ONEDNN},
-        {"fused_reshape_X", ExtraAttrProperty::ONEDNN},
-        {"fused_reshape_Y", ExtraAttrProperty::ONEDNN},
-        {"fused_transpose_X", ExtraAttrProperty::ONEDNN},
-        {"fused_transpose_Y", ExtraAttrProperty::ONEDNN},
        {"mkldnn_data_type", ExtraAttrProperty::ONEDNN},
        {"scale_x", ExtraAttrProperty::ONEDNN},
        {"scale_y", ExtraAttrProperty::ONEDNN},
@@ -226,8 +220,7 @@ class ExtraInfoUtils {
   std::unordered_map<std::string, std::vector<std::string>>
       g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}},
                                   {"conv2d_transpose", {"Bias"}},
-                                  {"conv2d_grad", {"Bias"}},
-                                  {"matmul_v2", {"ResidualData"}}};
+                                  {"conv2d_grad", {"Bias"}}};
   std::vector<std::string> empty_extra_input_names_;
 };
...
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -53,26 +53,31 @@ constexpr bool is_bfloat16() {
 static void AppendActivation(const OneDNNContext& dev_ctx,
                              dnnl::post_ops& post_ops,  // NOLINT
-                             float activation_scale = 1.0f) {
-  const auto invalid_attribute =
-      dev_ctx.HasDnnAttr("fuse_activation")
-          ? PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"))
-                .empty()
-          : true;
-  if (invalid_attribute) return;
-  const auto fuse_activation =
-      dev_ctx.HasDnnAttr("fuse_activation")
-          ? PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"))
-          : "";
-  const auto fuse_alpha =
-      dev_ctx.HasDnnAttr("fuse_alpha")
-          ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_alpha"))
-          : 0.0f;
-  const auto fuse_beta =
-      dev_ctx.HasDnnAttr("fuse_beta")
-          ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_beta"))
-          : 0.0f;
+                             float activation_scale = 1.0f,
+                             std::string fuse_activation = "",
+                             float fuse_alpha = 0.0f,
+                             float fuse_beta = 0.0f) {
+  if (fuse_activation == "") {
+    const auto invalid_attribute =
+        dev_ctx.HasDnnAttr("fuse_activation")
+            ? PADDLE_GET_CONST(std::string,
+                               dev_ctx.GetDnnAttr("fuse_activation"))
+                  .empty()
+            : true;
+    if (invalid_attribute) return;
+    fuse_activation =
+        dev_ctx.HasDnnAttr("fuse_activation")
+            ? PADDLE_GET_CONST(std::string,
+                               dev_ctx.GetDnnAttr("fuse_activation"))
+            : "";
+    fuse_alpha = dev_ctx.HasDnnAttr("fuse_alpha")
+                     ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_alpha"))
+                     : 0.0f;
+    fuse_beta = dev_ctx.HasDnnAttr("fuse_beta")
+                    ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_beta"))
+                    : 0.0f;
+  }
   if (fuse_activation == "hard_sigmoid") {
     post_ops.append_eltwise(activation_scale,
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
using dnnl::engine;
using dnnl::inner_product_forward;
using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;
using paddle::framework::ReshapeToMatrix;
namespace phi {
template <typename XT, typename YT, typename OT>
class FusedMatmulOneDNNHandler
: public funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
FusedMatmulOneDNNHandler(const OneDNNContext &dev_ctx,
const DenseTensor *residual_data,
const std::vector<int64_t> &x_org_dims,
const std::vector<int64_t> &y_org_dims,
bool trans_x,
bool trans_y,
const float matmul_alpha,
const std::vector<int64_t> &x_strides_override,
const std::vector<int64_t> &y_strides_override,
bool is_output_fused,
const std::string &fuse_activation,
const float fuse_alpha,
const float fuse_beta,
const float fused_output_scale,
const float scale_x,
const float scale_y,
const float scale_in_eltwise,
const float scale_out,
const bool force_fp32_output)
: funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(dev_ctx.GetEngine(),
dev_ctx.GetPlace()) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;
if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];
std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());
if (x_strides_override.empty()) {
if (trans_x) {
x_strides.insert(x_strides.end(), {M * K, 1, M});
} else {
x_strides.insert(x_strides.end(), {M * K, K, 1});
}
} else {
x_strides = x_strides_override;
}
if (y_strides_override.empty()) {
if (trans_y) {
y_strides.insert(y_strides.end(), {N * K, 1, K});
} else {
y_strides.insert(y_strides.end(), {N * K, N, 1});
}
} else {
y_strides = y_strides_override;
}
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) {
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
// TODO(jczaja): Why not for int8??
if (!funcs::is_int8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, funcs::OneDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, funcs::OneDNNGetDataType<YT>(), y_strides);
auto out_md =
memory::desc(out_ddims, funcs::OneDNNGetDataType<OT>(), out_strides);
const auto matmul_attrs = CreateMatmulAttrs(dev_ctx,
residual_data,
matmul_alpha,
fuse_activation,
fuse_alpha,
fuse_beta,
fused_output_scale,
scale_x,
scale_y,
scale_in_eltwise,
scale_out,
force_fp32_output);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
float ComputeOutputScale(float matmul_alpha,
const float scale_x,
const float scale_y,
const float scale_in_eltwise,
const float scale_out,
const bool force_fp32_output) {
float f_scale_out = force_fp32_output ? 1.0f : scale_out;
matmul_alpha *= f_scale_out / (scale_x * scale_y);
return matmul_alpha;
}
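// Worked example with illustrative numbers: matmul_alpha = 2.0f,
// scale_x = scale_y = 0.5f, scale_out = 4.0f and force_fp32_output = false
// give 2.0f * 4.0f / (0.5f * 0.5f) = 32.0f; with force_fp32_output = true,
// scale_out is replaced by 1.0f and the result is 2.0f / 0.25f = 8.0f.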
dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext &dev_ctx,
const DenseTensor *residual_data,
const float matmul_alpha,
const std::string &fuse_activation,
const float fuse_alpha,
const float fuse_beta,
const float fused_output_scale,
const float scale_x,
const float scale_y,
const float scale_in_eltwise,
const float scale_out,
const bool force_fp32_output) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float computed_scale_out = ComputeOutputScale(matmul_alpha,
scale_x,
scale_y,
scale_in_eltwise,
scale_out,
force_fp32_output);
if (computed_scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {computed_scale_out});
}
if (residual_data) {
auto residual_data_tz = vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
funcs::OneDNNGetDataType<OT>(),
dnnl::memory::format_tag::any);
post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md);
if (scale_in_eltwise != 0.0f) {
float sum_scale = scale_out / scale_in_eltwise;
post_operations.append_sum(sum_scale);
}
}
funcs::AppendActivation(
dev_ctx, post_operations, 1.0f, fuse_activation, fuse_alpha, fuse_beta);
if (fused_output_scale != 1.0f) {
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f);
}
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t> &matmul_out_dims) const {
// fusing matmul_v2 + transpose + reshape guarantees that the output is 4-D
// and the transpose axes are {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
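// Worked example: matmul_out_dims = {2, 3, 4, 5} with the fixed axis order
// {0, 2, 1, 3} yields fake_strides = {60, 5, 15, 1}, i.e. the primitive
// writes the buffer as if the output had already been permuted to
// {2, 4, 3, 5}.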
std::shared_ptr<memory> AcquireWeightsMemory(const DenseTensor *input) {
const YT *input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(), funcs::to_void_cast<YT>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(const OneDNNContext &dev_ctx,
DenseTensor *output) {
// We cannot use the base AcquireDstMemory, because it makes an allocation
// request based on the DST memory primitive size. This is fine in general,
// but in MatMul we have a primitive that covers only one batch of data and
// then shifts the pointer for every new batch. Hence the DenseTensor size is
// bigger than the dst memory primitive size, so we would request less memory
// than is actually there and trigger an assertion. As there is no 'any'
// format here, we can keep the default DenseTensor size as computed in
// ComputeInferShape
OT *ptr = dev_ctx.template Alloc<OT>(output);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
};
static DDim RowMatrixDimsFromVector(const DDim &x_dim) {
return x_dim.size() > 1 ? x_dim : make_ddim({1, x_dim[0]});
}
static DDim ColumnMatrixDimsFromVector(const DDim &y_dim) {
return y_dim.size() > 1 ? y_dim : make_ddim({y_dim[0], 1});
}
static std::vector<int64_t> TransposeAxis(const std::vector<int64_t> &x,
const std::vector<int> &axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
phi::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(
in_rank,
axis_size,
phi::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
phi::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
static std::vector<int64_t> GetInputStrides(const std::string input_name,
const DDim &input_dims,
std::vector<int> shape,
std::vector<int> axis,
const bool transpose_input) {
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto &MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
funcs::MatDescriptor mat_dim = funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0, transpose_input);
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = TransposeAxis(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
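// Example (illustrative values): for input_name = "X" with
// input_dims = [6, 4], shape = {2, 3, 4} and axis = {0, 2, 1}, the reshaped
// dims [2, 3, 4] have dense strides [12, 4, 1], which TransposeAxis reorders
// to [12, 1, 4]; the matmul then reads X as the [2, 4, 3] view that
// reshape+transpose would have produced.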
template <typename T, typename T_out>
void ExecuteFusedMatmul(const OneDNNContext &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor *residual_data,
const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims,
bool trans_x,
bool trans_y,
const float matmul_alpha,
const std::vector<int64_t> &x_strides_override,
const std::vector<int64_t> &y_strides_override,
const bool is_output_fused,
const std::vector<int> &fused_transpose_Out,
const std::string &fuse_activation,
const float fuse_alpha,
const float fuse_beta,
const float fused_output_scale,
const float scale_x,
const float scale_y,
const float scale_in_eltwise,
const float scale_out,
const bool force_fp32_output,
DenseTensor *out) {
FusedMatmulOneDNNHandler<T, T, T_out> handler(dev_ctx,
residual_data,
x_dims,
y_dims,
trans_x,
trans_y,
matmul_alpha,
x_strides_override,
y_strides_override,
is_output_fused,
fuse_activation,
fuse_alpha,
fuse_beta,
fused_output_scale,
scale_x,
scale_y,
scale_in_eltwise,
scale_out,
force_fp32_output);
const auto src_memory_p = handler.AcquireSrcMemory(&x);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y);
const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (residual_data) {
const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*residual_data_memory_p});
}
auto &astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
if (is_output_fused && !funcs::is_int8<T_out>()) {
auto permuted_md =
dst_memory_p->get_desc().permute_axes(fused_transpose_Out);
out->set_mem_desc(permuted_md.reshape(vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
}
std::vector<int64_t> GetInputShape(DDim input_dims,
std::vector<int> shape,
std::vector<int> axis) {
if (!shape.empty() && !axis.empty()) {
return vectorize(input_dims.reshape(shape).transpose(axis));
}
return vectorize(input_dims);
}
void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims,
std::vector<int64_t> *y_bd_dims,
DenseTensor *out,
const bool is_output_fused) {
if (x_dims.size() == 1) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
(*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
(*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
(*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
(*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
(*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
(*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
}
}
if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) {
auto out_dims = vectorize(out->dims());
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
(*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
(*y_bd_dims)[i] == 1,
true,
errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
i,
(*x_bd_dims)[i],
i,
(*y_bd_dims)[i]));
(out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
}
out->Resize(make_ddim((out_dims)));
}
}
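// Example (illustrative values): with ndims = 3, x_dims = [3, 4] is padded
// to x_bd_dims = [1, 3, 4] while y_dims = [5, 4, 2] is copied to
// y_bd_dims = [5, 4, 2]; the broadcast check above only runs when both
// inputs have more than two dims and the output is not fused.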
template <typename T, typename Context>
void FusedMatmulKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
const paddle::optional<DenseTensor> &residual_data,
bool transpose_x,
bool transpose_y,
const float matmul_alpha,
const std::string &fuse_activation,
const float fuse_alpha,
const float fuse_beta,
const float fused_output_scale,
const std::vector<int> &fused_reshape_X,
const std::vector<int> &fused_transpose_X,
const std::vector<int> &fused_reshape_Y,
const std::vector<int> &fused_transpose_Y,
const std::vector<int> &fused_reshape_Out,
const std::vector<int> &fused_transpose_Out,
const std::string &mkldnn_data_type,
const float scale_x,
const float scale_y,
const float scale_in_eltwise,
const float scale_out,
const bool force_fp32_output,
DenseTensor *out) {
if (dev_ctx.HasDnnAttr("head_number")) {
const auto head_number =
PADDLE_GET_CONST(int, dev_ctx.GetDnnAttr("head_number"));
PADDLE_ENFORCE_EQ(
head_number,
1,
errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
head_number));
}
constexpr bool is_int8 = funcs::is_int8<T>();
constexpr bool is_bfloat16 = funcs::is_bfloat16<T>();
bool fuse_relu = false;
if (fuse_activation == "relu" || fuse_activation == "relu6") {
fuse_relu = true;
}
auto x_dims = GetInputShape(x.dims(), fused_reshape_X, fused_transpose_X);
auto y_dims = GetInputShape(y.dims(), fused_reshape_Y, fused_transpose_Y);
auto is_output_fused =
!fused_reshape_Out.empty() && !fused_transpose_Out.empty();
auto x_strides_override = GetInputStrides(
"X", x.dims(), fused_reshape_X, fused_transpose_X, transpose_x);
auto y_strides_override = GetInputStrides(
"Y", y.dims(), fused_reshape_Y, fused_transpose_Y, transpose_y);
int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3);
std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(
x_dims, y_dims, &x_bd_dims, &y_bd_dims, out, is_output_fused);
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
ExecuteFusedMatmul<T, float>(dev_ctx,
x,
y,
residual_data.get_ptr(),
x_bd_dims,
y_bd_dims,
transpose_x,
transpose_y,
matmul_alpha,
x_strides_override,
y_strides_override,
is_output_fused,
fused_transpose_Out,
fuse_activation,
fuse_alpha,
fuse_beta,
fused_output_scale,
scale_x,
scale_y,
scale_in_eltwise,
scale_out,
force_fp32_output,
out);
} else if (is_bfloat16) {
ExecuteFusedMatmul<T, phi::dtype::bfloat16>(dev_ctx,
x,
y,
residual_data.get_ptr(),
x_bd_dims,
y_bd_dims,
transpose_x,
transpose_y,
matmul_alpha,
x_strides_override,
y_strides_override,
is_output_fused,
fused_transpose_Out,
fuse_activation,
fuse_alpha,
fuse_beta,
fused_output_scale,
scale_x,
scale_y,
scale_in_eltwise,
scale_out,
force_fp32_output,
out);
} else if (fuse_relu) {
ExecuteFusedMatmul<T, uint8_t>(dev_ctx,
x,
y,
residual_data.get_ptr(),
x_bd_dims,
y_bd_dims,
transpose_x,
transpose_y,
matmul_alpha,
x_strides_override,
y_strides_override,
is_output_fused,
fused_transpose_Out,
fuse_activation,
fuse_alpha,
fuse_beta,
fused_output_scale,
scale_x,
scale_y,
scale_in_eltwise,
scale_out,
force_fp32_output,
out);
} else {
ExecuteFusedMatmul<T, int8_t>(dev_ctx,
x,
y,
residual_data.get_ptr(),
x_bd_dims,
y_bd_dims,
transpose_x,
transpose_y,
matmul_alpha,
x_strides_override,
y_strides_override,
is_output_fused,
fused_transpose_Out,
fuse_activation,
fuse_alpha,
fuse_beta,
fused_output_scale,
scale_x,
scale_y,
scale_in_eltwise,
scale_out,
force_fp32_output,
out);
}
}
} // namespace phi
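
// Register the oneDNN fused_matmul kernel for the input data types produced
// by the fuse and quantization passes.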
PD_REGISTER_KERNEL(fused_matmul,
OneDNN,
ONEDNN,
phi::FusedMatmulKernel,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -51,8 +51,7 @@ void CalculateMatrixDims(const std::vector<int64_t> &x_dims, ...@@ -51,8 +51,7 @@ void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims, const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims, std::vector<int64_t> *x_bd_dims,
std::vector<int64_t> *y_bd_dims, std::vector<int64_t> *y_bd_dims,
DenseTensor *out, DenseTensor *out) {
const bool is_output_fused) {
if (x_dims.size() == 1) { if (x_dims.size() == 1) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0]; (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) { } else if (x_dims.size() == 2) {
...@@ -74,7 +73,7 @@ void CalculateMatrixDims(const std::vector<int64_t> &x_dims, ...@@ -74,7 +73,7 @@ void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
} }
} }
if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) { if (x_dims.size() > 2 && y_dims.size() > 2) {
auto out_dims = vectorize(out->dims()); auto out_dims = vectorize(out->dims());
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) { for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -121,15 +120,6 @@ void MatmulKernel(const Context &dev_ctx, ...@@ -121,15 +120,6 @@ void MatmulKernel(const Context &dev_ctx,
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
: false; : false;
bool fuse_relu = false;
if (dev_ctx.HasDnnAttr("fuse_activation")) {
auto act_type =
PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"));
if (act_type == "relu" || act_type == "relu6") {
fuse_relu = true;
}
}
auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X")); auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X"));
auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y")); auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y"));
...@@ -139,12 +129,7 @@ void MatmulKernel(const Context &dev_ctx, ...@@ -139,12 +129,7 @@ void MatmulKernel(const Context &dev_ctx,
std::vector<int64_t> x_bd_dims(ndims, 1); std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1); std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(x_dims, CalculateMatrixDims(x_dims, y_dims, &x_bd_dims, &y_bd_dims, out);
y_dims,
&x_bd_dims,
&y_bd_dims,
out,
funcs::IsOutputFused(dev_ctx));
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
funcs::ExecuteMatmul<T, float>( funcs::ExecuteMatmul<T, float>(
...@@ -152,9 +137,6 @@ void MatmulKernel(const Context &dev_ctx, ...@@ -152,9 +137,6 @@ void MatmulKernel(const Context &dev_ctx,
} else if (is_bfloat16) { } else if (is_bfloat16) {
funcs::ExecuteMatmul<T, paddle::platform::bfloat16>( funcs::ExecuteMatmul<T, paddle::platform::bfloat16>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out); dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else if (fuse_relu) {
funcs::ExecuteMatmul<T, uint8_t>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else { } else {
funcs::ExecuteMatmul<T, int8_t>( funcs::ExecuteMatmul<T, int8_t>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out); dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
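
// Maps the fused_matmul operator's attributes onto the phi fused_matmul
// kernel; the attribute order here must match the parameter order of
// FusedMatmulKernel.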
KernelSignature FusedMatmulOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fused_matmul",
{"X", "Y", "ResidualData"},
{"trans_x",
"trans_y",
"matmul_alpha",
"fuse_activation",
"fuse_alpha",
"fuse_beta",
"fused_output_scale",
"fused_reshape_X",
"fused_transpose_X",
"fused_reshape_Y",
"fused_transpose_Y",
"fused_reshape_Out",
"fused_transpose_Out",
"mkldnn_data_type",
"Scale_x",
"Scale_y",
"Scale_in_eltwise",
"Scale_out",
"force_fp32_output"},
{"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(fused_matmul, phi::FusedMatmulOpArgumentMapping);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -103,12 +103,6 @@ class TestMapMatmulToMulPass(PassAutoScanTest): ...@@ -103,12 +103,6 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
alpha=alpha, alpha=alpha,
trans_x=transpose_X, trans_x=transpose_X,
trans_y=transpose_Y, trans_y=transpose_Y,
fused_reshape_Out=[],
fused_transpose_Out=[],
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
) )
ops = [ ops = [
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -92,12 +92,6 @@ class TestMapMatmulToMulPass(PassAutoScanTest): ...@@ -92,12 +92,6 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
alpha=alpha, alpha=alpha,
trans_x=transpose_X, trans_x=transpose_X,
trans_y=transpose_Y, trans_y=transpose_Y,
fused_reshape_Out=[],
fused_transpose_Out=[],
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
) )
ops = [ ops = [
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -76,12 +76,6 @@ class TestMatmulV2ScaleFusePass(PassAutoScanTest): ...@@ -76,12 +76,6 @@ class TestMatmulV2ScaleFusePass(PassAutoScanTest):
outputs={"Out": ["matmul_out"]}, outputs={"Out": ["matmul_out"]},
trans_x=transpose_X, trans_x=transpose_X,
trans_y=transpose_Y, trans_y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
) )
is_scale_tensor = draw(st.booleans()) is_scale_tensor = draw(st.booleans())
if is_scale_tensor: if is_scale_tensor:
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the 'License'); # Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -146,7 +146,7 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest): ...@@ -146,7 +146,7 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
'operator_scale_onednn_fuse_pass', 'operator_scale_onednn_fuse_pass',
], ],
) )
yield config, ['matmul'], (1e-5, 1e-5) yield config, ['fused_matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the 'License'); # Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -137,7 +137,7 @@ class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest): ...@@ -137,7 +137,7 @@ class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest):
'matmul_activation_mkldnn_fuse_pass', 'matmul_activation_mkldnn_fuse_pass',
], ],
) )
yield config, ['matmul'], (1e-5, 1e-5) yield config, ['fused_matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the 'License'); # Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -76,7 +76,7 @@ class TestMatmulElementwiseAddMkldnnFusePass(PassAutoScanTest): ...@@ -76,7 +76,7 @@ class TestMatmulElementwiseAddMkldnnFusePass(PassAutoScanTest):
config = self.create_inference_config( config = self.create_inference_config(
use_mkldnn=True, passes=['matmul_elementwise_add_mkldnn_fuse_pass'] use_mkldnn=True, passes=['matmul_elementwise_add_mkldnn_fuse_pass']
) )
yield config, ['matmul'], (1e-5, 1e-5) yield config, ['fused_matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -150,7 +150,7 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest): ...@@ -150,7 +150,7 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
'operator_scale_onednn_fuse_pass', 'operator_scale_onednn_fuse_pass',
], ],
) )
yield config, ['matmul_v2'], (1e-5, 1e-5) yield config, ['fused_matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the 'License'); # Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -87,7 +87,7 @@ class TestMatmulV2ElementwiseAddMkldnnFusePass(PassAutoScanTest): ...@@ -87,7 +87,7 @@ class TestMatmulV2ElementwiseAddMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ['matmul_v2'], (1e-5, 1e-5) yield config, ['fused_matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -90,12 +90,6 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest): ...@@ -90,12 +90,6 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
attrs={ attrs={
"trans_x": transpose_X, "trans_x": transpose_X,
"trans_y": transpose_Y, "trans_y": transpose_Y,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
) )
...@@ -135,17 +129,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest): ...@@ -135,17 +129,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
return program_config return program_config
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
# gpu_cpu_map_matmul_v2_to_matmul_pass will affect the type of final fused op
fused_op = "matmul_v2"
input1_dim1 = program_config.inputs["input_data1"].shape[0]
input2_dim1 = program_config.inputs["input_data2"].shape[0]
input1_dim2 = program_config.inputs["input_data1"].shape[1]
input2_dim2 = program_config.inputs["input_data2"].shape[1]
if input1_dim1 == input2_dim1 and input1_dim2 == input2_dim2:
fused_op = "matmul"
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, [fused_op], (1e-5, 1e-5) yield config, ['fused_matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -93,12 +93,6 @@ class TestMkldnnMatmulv2Op(MkldnnAutoScanTest): ...@@ -93,12 +93,6 @@ class TestMkldnnMatmulv2Op(MkldnnAutoScanTest):
attrs={ attrs={
"trans_x": kwargs["transpose_X"], "trans_x": kwargs["transpose_X"],
"trans_y": kwargs["transpose_Y"], "trans_y": kwargs["transpose_Y"],
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
) )
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -116,7 +116,7 @@ class TestOneDNNMatmulTransposeReshapeFusePass(PassAutoScanTest): ...@@ -116,7 +116,7 @@ class TestOneDNNMatmulTransposeReshapeFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ['matmul'], (1e-5, 1e-5) yield config, ['fused_matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -154,7 +154,7 @@ class TestOneDNNReshapeTransposeMatmulFusePass(PassAutoScanTest): ...@@ -154,7 +154,7 @@ class TestOneDNNReshapeTransposeMatmulFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ['matmul'], (1e-5, 1e-5) yield config, ['fused_matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......