未验证 提交 3333a439 编写于 作者: S Sławomir Siwek 提交者: GitHub

matmul+activation fuse pass (#43519)

* add method for post ops

* format code

* gpd

* format style

* add matmul+act test

* implement matmul+activation

* whitespaces

* code style

* python code format

* Increase UT timeout

* code format

* update style

* generalize activation fuse passes

* change order

* Unify activation GPD

* Revert changes with op_act

* remove softmax mkldnn attrs

* set common name for act attributes

* whitespace

* append postops by helper function

* ut style

* revert changes related to quantization

* Reduce redundancy

* reduce number of parameters

* trigger CI

* validate attribute

* trim unit test
上级 40b68630
......@@ -212,6 +212,7 @@ if(WITH_MKLDNN)
pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn)
pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
......
......@@ -26,7 +26,6 @@ using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = paddle::platform::GetSupportedActivations();
std::vector<std::string> conv_types = {"conv2d"};
for (const auto& conv_type : conv_types)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = paddle::platform::GetSupportedActivations();
std::vector<std::string> matmul_types = {"matmul"};
for (const auto& matmul_type : matmul_types)
for (auto& act_type : act_types) {
FuseMatmulAct(graph, matmul_type, act_type);
}
}
void MatmulActivationMkldnnFusePass::FuseMatmulAct(
Graph* graph, const std::string& matmul_type, std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(matmul_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorActivation matmul_act_pattern(
gpd.mutable_pattern(), "matmul_activation_mkldnn_fuse");
matmul_act_pattern(matmul_type, act_type);
int found_matmul_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle " + matmul_type + "+" + act_type + " fuse";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "matmul_activation_mkldnn_fuse_pass op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(matmul, preceding_op, matmul_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, preceding_op_out, matmul_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation, matmul_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, matmul_act_pattern);
OpDesc* matmul_op = matmul->Op();
OpDesc* act_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
matmul_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
act_type = BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
}
matmul_op->SetAttr("fuse_activation", act_type);
matmul_op->SetOutput("Out", {activation_out->Name()});
IR_NODE_LINK_TO(matmul, activation_out);
GraphSafeRemoveNodes(graph, {activation, matmul_out});
found_matmul_activation_count++;
};
gpd(graph, handler);
AddStatis(found_matmul_activation_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
PrettyLogDetail("--- fused %d matmul with %s activation",
found_matmul_activation_count,
act_type);
}
}
MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("abs"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("clip"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("min")
.End()
.AddAttr("max")
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("hard_sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("slope")
.IsOptional()
.IsType<float>()
.End()
.AddAttr("offset")
.IsOptional()
.IsType<float>()
.End();
AddOpCompat(OpCompat("hard_swish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("threshold")
.IsOptional()
.IsType<float>()
.End()
.AddAttr("scale")
.IsOptional()
.IsType<float>()
.End()
.AddAttr("offset")
.IsOptional()
.IsType<float>()
.End();
AddOpCompat(OpCompat("leaky_relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End();
AddOpCompat(OpCompat("mish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("relu6"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("threshold")
.IsType<float>()
.End();
AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("sqrt"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("swish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("beta")
.IsType<float>()
.End();
AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_activation_mkldnn_fuse_pass,
paddle::framework::ir::MatmulActivationMkldnnFusePass);
REGISTER_PASS_CAPABILITY(matmul_activation_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("abs", 0)
.LE("clip", 1)
.EQ("gelu", 0)
.EQ("hard_sigmoid", 0)
.LE("hard_swish", 0)
.LE("leaky_relu", 1)
.LE("mish", 1)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class MatmulActivationMkldnnFusePass : public FusePassBase {
public:
MatmulActivationMkldnnFusePass();
virtual ~MatmulActivationMkldnnFusePass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseMatmulAct(Graph *graph,
const std::string &matmul_type,
std::string &act_type) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -50,9 +50,9 @@ void MainTest(const std::string& activation_type) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("fuse_activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto activation_type =
BOOST_GET_CONST(std::string, op->GetAttr("fuse_activation_type"));
BOOST_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(activation_type.compare(activation_type), 0);
}
}
......
......@@ -302,6 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"softplus_activation_mkldnn_fuse_pass", //
"shuffle_channel_mkldnn_detect_pass", //
"elt_act_mkldnn_fuse_pass", //
"matmul_activation_mkldnn_fuse_pass", //
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
// "mkldnn_inplace_pass", // This pass should be activated after
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <tuple>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
using dnnl::memory;
using dnnl::primitive;
......@@ -453,6 +454,8 @@ class MatMulMKLDNNHandler
matmul_attrs.set_output_scales(0, {scale_out});
}
paddle::platform::AppendActivation(ctx, post_operations);
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
import hypothesis.strategies as st
class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
transpose_X = draw(st.booleans())
transpose_Y = draw(st.booleans())
alpha = draw(st.sampled_from([1, 2]))
batch_size = draw(st.sampled_from([4]))
channel = draw(st.sampled_from([8]))
input_dim = draw(st.sampled_from([32]))
activation_type = draw(
st.sampled_from([
'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt',
'hard_swish', 'sigmoid', 'abs', 'relu6', 'clip', 'tanh',
'hard_sigmoid', 'leaky_relu'
]))
def generate_input(type):
if transpose_X and transpose_Y:
shape_x = [batch_size, channel, input_dim, 32]
shape_y = [batch_size, channel, 64, input_dim]
elif transpose_X:
shape_x = [batch_size, channel, input_dim, 32]
shape_y = [batch_size, channel, input_dim, 64]
elif transpose_Y:
shape_x = [batch_size, channel, 32, input_dim]
shape_y = [batch_size, channel, 8, input_dim]
else:
shape_x = [batch_size, channel, 32, input_dim]
shape_y = [batch_size, channel, input_dim, 16]
if type == 'x':
return np.random.random(shape_x).astype(np.float32)
else:
return np.random.random(shape_y).astype(np.float32)
matmul_op = OpConfig(type='matmul',
inputs={
'X': ['matmul_X'],
'Y': ['matmul_Y']
},
outputs={'Out': ['matmul_output']},
attrs={
'transpose_X': transpose_X,
'transpose_Y': transpose_Y,
'alpha': alpha
})
if activation_type == "relu6":
activation_op = OpConfig(activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
threshold=draw(
st.floats(min_value=1.0,
max_value=10.0)))
elif activation_type == "leaky_relu":
activation_op = OpConfig(activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
alpha=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == "swish":
activation_op = OpConfig(activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
beta=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == "clip":
activation_op = OpConfig(
activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)))
else:
activation_op = OpConfig(activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]})
model_net = [matmul_op, activation_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
'matmul_X': TensorConfig(data_gen=partial(generate_input, 'x')),
'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'y'))
},
outputs=['activation_output'])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ['matmul'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
max_examples=30,
passes=['matmul_activation_mkldnn_fuse_pass'])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册