Unverified commit 99c872fa authored by Sławomir Siwek, committed by GitHub

FC/matmul(v2) + scale fuse pass (#47420)

Parent 559b9754
......@@ -218,6 +218,7 @@ if(WITH_MKLDNN)
pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorScaleOneDNNPass::ApplyImpl(Graph *graph) const {
const std::vector<std::string> fusable_ops{"fc", "matmul", "matmul_v2"};
for (const auto &op : fusable_ops) FuseScale(graph, op);
}
void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
const std::string &op_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_scale_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorActivation op_scale_pattern(
gpd.mutable_pattern(), op_type + "_scale_onednn_fuse_pass");
op_scale_pattern(op_type, "scale");
int found_operator_scale_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(operator_out, preceding_op_out, op_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_op, activation, op_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, activation_out, op_scale_pattern);
if (operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn")))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with scale.";
return;
}
if (scale_op->Op()->GetAttrIfExists<float>("bias") != 0.0) {
VLOG(4) << op_type << " can be fused only with unbiased scale.";
return;
}
float scale = PADDLE_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
auto *scope = param_scope();
auto const &names = scale_op->Op()->InputNames();
bool has_scale_tensor =
std::find(names.begin(), names.end(), "ScaleTensor") != names.end();
if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
std::string scale_var_name = scale_op->Op()->Input("ScaleTensor").front();
auto *scale_var = scope->FindVar(scale_var_name);
// ScaleTensor must be a weight present in the scope; otherwise skip the fusion
if (scale_var == nullptr) return;
auto *scale_tensor = scale_var->GetMutable<LoDTensor>();
scale = *(scale_tensor->data<float>());
}
operator_op->Op()->SetAttr("fused_output_scale", scale);
operator_op->Op()->SetOutput("Out", {scale_out->Name()});
IR_OP_VAR_LINK(operator_op, scale_out);
GraphSafeRemoveNodes(g, {scale_op, operator_out});
found_operator_scale_count++;
};
gpd(graph, handler);
AddStatis(found_operator_scale_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_scale_count > 0)
PrettyLogDetail(
"--- fused %d %s with scale", found_operator_scale_count, op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_scale_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorScaleOneDNNPass);
REGISTER_PASS_CAPABILITY(operator_scale_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("fc", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("scale", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorScaleOneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorScaleOneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseScale(Graph *graph, const std::string &op_type) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -326,6 +326,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"softplus_activation_mkldnn_fuse_pass", //
"shuffle_channel_mkldnn_detect_pass", //
"elt_act_mkldnn_fuse_pass", //
"operator_scale_onednn_fuse_pass", //
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
// "mkldnn_inplace_pass", // This pass should be activated after
......@@ -419,6 +420,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("scale_matmul_fuse_pass");
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("operator_scale_onednn_fuse_pass");
passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass");
......
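Because the pass is appended to CpuPassStrategy::EnableMKLDNN() and to the int8 pipeline above, it runs automatically whenever oneDNN is enabled on an inference config. A hedged usage sketch, with placeholder model file names:

import paddle.inference as paddle_infer

# Placeholder model files; any fp32 model containing fc/matmul followed by scale qualifies.
config = paddle_infer.Config('model.pdmodel', 'model.pdiparams')
config.enable_mkldnn()  # CpuPassStrategy::EnableMKLDNN now includes operator_scale_onednn_fuse_pass
predictor = paddle_infer.create_predictor(config)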
......@@ -533,6 +533,12 @@ class FCPrimitiveFactory {
scale, dnnl::algorithm::eltwise_hardswish, alpha, beta);
}
if (ctx.HasAttr("fused_output_scale")) {
float scale_alpha = ctx.Attr<float>("fused_output_scale");
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
attributes.set_post_ops(post_operations);
return attributes;
}
......
......@@ -250,6 +250,12 @@ class MatMulV2MKLDNNHandler
AppendActivation(ctx, post_operations);
if (ctx.HasAttr("fused_output_scale")) {
float scale_alpha = ctx.Attr<float>("fused_output_scale");
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
......
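Both kernels map the recorded multiplier onto a oneDNN eltwise_linear post-op, f(v) = alpha * v + beta, with alpha = fused_output_scale and beta = 0, which is numerically identical to the removed unbiased scale op. A small NumPy check of that identity (illustrative only, not part of the patch):

import numpy as np

x = np.random.rand(4, 8).astype(np.float32)
y = np.random.rand(8, 16).astype(np.float32)
fused_output_scale = 0.125

# Graph before the pass: matmul followed by a standalone unbiased scale op.
out_unfused = fused_output_scale * (x @ y)

# Graph after the pass: matmul with an eltwise_linear post-op, alpha = scale, beta = 0.
alpha, beta = fused_output_scale, 0.0
out_fused = alpha * (x @ y) + beta

assert np.allclose(out_unfused, out_fused)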
......@@ -21,7 +21,6 @@ import hypothesis.strategies as st
class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
transpose_X = draw(st.booleans())
transpose_Y = draw(st.booleans())
......@@ -30,11 +29,25 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
channel = draw(st.sampled_from([8]))
input_dim = draw(st.sampled_from([32]))
activation_type = draw(
st.sampled_from([
'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
'leaky_relu'
]))
st.sampled_from(
[
'relu',
'gelu',
'swish',
'mish',
'sqrt',
'hard_swish',
'sigmoid',
'abs',
'relu6',
'clip',
'tanh',
'hard_sigmoid',
'leaky_relu',
'scale',
]
)
)
def generate_input(type):
if transpose_X and transpose_Y:
......@@ -55,50 +68,60 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
else:
return np.random.random(shape_y).astype(np.float32)
matmul_op = OpConfig(type='matmul',
inputs={
'X': ['matmul_X'],
'Y': ['matmul_Y']
},
outputs={'Out': ['matmul_output']},
attrs={
'transpose_X': transpose_X,
'transpose_Y': transpose_Y,
'alpha': alpha
})
matmul_op = OpConfig(
type='matmul',
inputs={'X': ['matmul_X'], 'Y': ['matmul_Y']},
outputs={'Out': ['matmul_output']},
attrs={
'transpose_X': transpose_X,
'transpose_Y': transpose_Y,
'alpha': alpha,
'use_mkldnn': True,
},
)
if activation_type == "relu6":
activation_op = OpConfig(activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
threshold=draw(
st.floats(min_value=1.0,
max_value=10.0)))
activation_op = OpConfig(
activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
)
elif activation_type == "leaky_relu":
activation_op = OpConfig(activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
alpha=draw(
st.floats(min_value=0.1,
max_value=1.0)))
activation_op = OpConfig(
activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
alpha=draw(st.floats(min_value=0.1, max_value=1.0)),
)
elif activation_type == "scale":
activation_op = OpConfig(
activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
scale=draw(st.sampled_from([0.125, 0.4, 0.875, 2])),
)
elif activation_type == "swish":
activation_op = OpConfig(activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
beta=draw(
st.floats(min_value=0.1,
max_value=1.0)))
activation_op = OpConfig(
activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
beta=draw(st.floats(min_value=0.1, max_value=1.0)),
)
elif activation_type == "clip":
activation_op = OpConfig(
activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)))
max=draw(st.floats(min_value=0.5, max_value=1.0)),
)
else:
activation_op = OpConfig(activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]})
activation_op = OpConfig(
activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
)
model_net = [matmul_op, activation_op]
......@@ -107,20 +130,32 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
weights={},
inputs={
'matmul_X': TensorConfig(data_gen=partial(generate_input, 'x')),
'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'y'))
'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'y')),
},
outputs=['activation_output'])
outputs=['activation_output'],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
config = self.create_inference_config(
use_mkldnn=True,
passes=[
'matmul_activation_mkldnn_fuse_pass',
'operator_scale_onednn_fuse_pass',
],
)
yield config, ['matmul'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
max_examples=30,
passes=['matmul_activation_mkldnn_fuse_pass'])
self.run_and_statis(
quant=False,
max_examples=50,
passes=[
'matmul_activation_mkldnn_fuse_pass',
'operator_scale_onednn_fuse_pass',
],
)
if __name__ == '__main__':
......
......@@ -21,7 +21,6 @@ import hypothesis.strategies as st
class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
transpose_X = draw(st.booleans())
transpose_Y = draw(st.booleans())
......@@ -29,11 +28,25 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
channel = draw(st.sampled_from([16, 32, 64]))
input_dim = draw(st.sampled_from([16, 32, 64]))
activation_type = draw(
st.sampled_from([
'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
'leaky_relu'
]))
st.sampled_from(
[
'relu',
'gelu',
'swish',
'mish',
'sqrt',
'hard_swish',
'sigmoid',
'abs',
'relu6',
'clip',
'tanh',
'hard_sigmoid',
'leaky_relu',
'scale',
]
)
)
def generate_input(type):
broadcast_X = st.booleans()
......@@ -60,49 +73,59 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
else:
return np.random.random(shape_y).astype(np.float32)
matmul_op = OpConfig(type='matmul_v2',
inputs={
'X': ['matmul_X'],
'Y': ['matmul_Y']
},
outputs={'Out': ['matmul_output']},
attrs={
'trans_x': transpose_X,
'trans_y': transpose_Y
})
matmul_op = OpConfig(
type='matmul_v2',
inputs={'X': ['matmul_X'], 'Y': ['matmul_Y']},
outputs={'Out': ['matmul_output']},
attrs={
'trans_x': transpose_X,
'trans_y': transpose_Y,
'use_mkldnn': True,
},
)
if activation_type == 'relu6':
activation_op = OpConfig(activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
threshold=draw(
st.floats(min_value=1.0,
max_value=10.0)))
activation_op = OpConfig(
activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
)
elif activation_type == 'leaky_relu':
activation_op = OpConfig(activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
alpha=draw(
st.floats(min_value=0.1,
max_value=1.0)))
activation_op = OpConfig(
activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
alpha=draw(st.floats(min_value=0.1, max_value=1.0)),
)
elif activation_type == "scale":
activation_op = OpConfig(
activation_type,
inputs={"X": ["matmul_output"]},
outputs={"Out": ["activation_output"]},
scale=draw(st.sampled_from([0.125, 0.4, 0.875, 2])),
)
elif activation_type == 'swish':
activation_op = OpConfig(activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
beta=draw(
st.floats(min_value=0.1,
max_value=1.0)))
activation_op = OpConfig(
activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
beta=draw(st.floats(min_value=0.1, max_value=1.0)),
)
elif activation_type == 'clip':
activation_op = OpConfig(
activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)))
max=draw(st.floats(min_value=0.5, max_value=1.0)),
)
else:
activation_op = OpConfig(activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']})
activation_op = OpConfig(
activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
)
model_net = [matmul_op, activation_op]
......@@ -111,20 +134,32 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
weights={},
inputs={
'matmul_X': TensorConfig(data_gen=partial(generate_input, 'X')),
'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'Y'))
'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'Y')),
},
outputs=['activation_output'])
outputs=['activation_output'],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
config = self.create_inference_config(
use_mkldnn=True,
passes=[
'matmul_activation_mkldnn_fuse_pass',
'operator_scale_onednn_fuse_pass',
],
)
yield config, ['matmul_v2'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
max_examples=30,
passes=['matmul_activation_mkldnn_fuse_pass'])
self.run_and_statis(
quant=False,
max_examples=50,
passes=[
'matmul_activation_mkldnn_fuse_pass',
'operator_scale_onednn_fuse_pass',
],
)
if __name__ == '__main__':
......