未验证 提交 a0930484 编写于 作者: S Sławomir Siwek 提交者: GitHub

eltwises + scale fuse pass (#48400)

上级 98aaf797
......@@ -25,7 +25,15 @@ namespace ir {
using string::PrettyLogDetail;
// Entry point of the pass: invokes FuseScale on `graph` once for every
// operator type in the list below.
void FuseOperatorScaleOneDNNPass::ApplyImpl(Graph *graph) const {
  // Single definition of the operator list. The earlier three-element
  // declaration of the same name was removed because declaring
  // `fusable_ops` twice in one scope is a redefinition error.
  const std::vector<std::string> fusable_ops{
      "fc",
      "matmul",
      "matmul_v2",
      "elementwise_add",
      "elementwise_sub",
      "elementwise_mul",
      "elementwise_div",
  };
  for (const auto &op : fusable_ops) FuseScale(graph, op);
}
......@@ -105,4 +113,8 @@ REGISTER_PASS_CAPABILITY(operator_scale_onednn_fuse_pass)
.EQ("fc", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.LE("elementwise_add", 1)
.LE("elementwise_sub", 1)
.LE("elementwise_mul", 1)
.LE("elementwise_div", 1)
.EQ("scale", 0));
......@@ -116,6 +116,11 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
// Builds the post-op chain for this kernel from the execution context.
// platform::AppendActivation is applied first; when the op carries a
// "fused_output_scale" attribute, an eltwise_linear post-op is appended
// after it, receiving the attribute value as its alpha argument and 0 as
// its beta argument.
dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const {
  dnnl::post_ops chain;
  platform::AppendActivation(ctx, chain);

  // Nothing more to append when no output scale was attached to this op.
  if (!ctx.HasAttr("fused_output_scale")) {
    return chain;
  }

  const float fused_scale = ctx.Attr<float>("fused_output_scale");
  chain.append_eltwise(
      1.0, dnnl::algorithm::eltwise_linear, fused_scale, 0.0f);
  return chain;
}
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import ProgramConfig, TensorConfig
class TestElementWiseAddReluFusePass(PassAutoScanTest):
    """Auto-scan test running an elementwise_add -> relu program through
    ``elt_act_mkldnn_fuse_pass`` with MKLDNN enabled."""

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        # Every sampled program is accepted.
        return True

    def sample_program_config(self, draw):
        """Sample a two-op program (elementwise_add followed by relu) whose
        two inputs are random float32 tensors of shape
        [batch_size, 3, 100, 100], with batch_size drawn from 1..4."""
        batch_size = draw(st.integers(min_value=1, max_value=4))

        def generate_input():
            shape = [batch_size, 3, 100, 100]
            return np.random.random(shape).astype(np.float32)

        add_op = {
            "op_type": "elementwise_add",
            "op_inputs": {"X": ["A"], "Y": ["B"]},
            "op_outputs": {"Out": ["add_output"]},
            "op_attrs": {},
        }
        relu_op = {
            "op_type": "relu",
            "op_inputs": {"X": ["add_output"]},
            "op_outputs": {"Out": ["relu_output"]},
            "op_attrs": {},
        }
        ops = self.generate_op_config([add_op, relu_op])

        # One independent data generator per input tensor.
        input_tensors = {
            name: TensorConfig(data_gen=partial(generate_input))
            for name in ("A", "B")
        }
        return ProgramConfig(
            ops=ops,
            weights={},
            inputs=input_tensors,
            outputs=["relu_output"],
        )

    def sample_predictor_configs(self, program_config):
        """Yield a single MKLDNN-enabled inference configuration."""
        mkldnn_config = self.create_inference_config(use_mkldnn=True)
        # NOTE(review): presumably the op list is what should remain after
        # fusion and the pair is the comparison tolerance -- confirm against
        # PassAutoScanTest.
        tolerance = (1e-5, 1e-5)
        yield mkldnn_config, ["elementwise_add"], tolerance

    def test(self):
        fuse_passes = ["elt_act_mkldnn_fuse_pass"]
        self.run_and_statis(
            quant=False, passes=fuse_passes, min_success_num=4
        )
# Run the test case when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestElementwiseAddActivationOneDNNFusePass(PassAutoScanTest):
    """Auto-scan test running an elementwise_add followed by a sampled
    activation (or scale) op through ``elt_act_mkldnn_fuse_pass`` and
    ``operator_scale_onednn_fuse_pass``."""

    def sample_program_config(self, draw):
        """Sample a two-op program: a oneDNN elementwise_add feeding one
        randomly chosen follow-up op, with attributes drawn per op type."""
        batch_size = draw(st.sampled_from([1, 32]))

        candidate_activations = [
            'relu',
            'gelu',
            'swish',
            'mish',
            'sqrt',
            'hard_swish',
            'sigmoid',
            'abs',
            'relu6',
            'clip',
            'tanh',
            'hard_sigmoid',
            'leaky_relu',
            'scale',
        ]
        activation_type = draw(st.sampled_from(candidate_activations))

        def generate_input():
            shape = [batch_size, 3, 100, 100]
            return np.random.random(shape).astype(np.float32)

        elementwise_op = OpConfig(
            type='elementwise_add',
            inputs={'X': ['eltwise_X'], 'Y': ['eltwise_Y']},
            outputs={'Out': ['eltwise_output']},
            attrs={'use_mkldnn': True},
        )

        # Lazy per-op attribute samplers: only the sampler for the chosen op
        # is invoked, so values are drawn exactly as in an if/elif chain.
        attr_samplers = {
            'relu6': lambda: {
                'threshold': draw(st.floats(min_value=1.0, max_value=10.0)),
            },
            'leaky_relu': lambda: {
                'alpha': draw(st.floats(min_value=0.1, max_value=1.0)),
            },
            'scale': lambda: {
                'scale': draw(st.sampled_from([0.125, 0.4, 0.875, 2])),
            },
            'swish': lambda: {
                'beta': draw(st.floats(min_value=0.1, max_value=1.0)),
            },
            'clip': lambda: {
                'min': draw(st.floats(min_value=0.1, max_value=0.49)),
                'max': draw(st.floats(min_value=0.5, max_value=1.0)),
            },
        }
        # Ops without an entry take no extra attributes (`dict()` -> {}).
        extra_attrs = attr_samplers.get(activation_type, dict)()

        activation_op = OpConfig(
            activation_type,
            inputs={'X': ['eltwise_output']},
            outputs={'Out': ['activation_output']},
            **extra_attrs,
        )

        input_tensors = {
            name: TensorConfig(data_gen=partial(generate_input))
            for name in ('eltwise_X', 'eltwise_Y')
        }
        return ProgramConfig(
            ops=[elementwise_op, activation_op],
            weights={},
            inputs=input_tensors,
            outputs=['activation_output'],
        )

    def sample_predictor_configs(self, program_config):
        """Yield a single MKLDNN-enabled configuration with both fuse passes."""
        fuse_passes = [
            'elt_act_mkldnn_fuse_pass',
            'operator_scale_onednn_fuse_pass',
        ]
        config = self.create_inference_config(
            use_mkldnn=True, passes=fuse_passes
        )
        # NOTE(review): presumably the op list is what should remain after
        # fusion and the pair is the comparison tolerance -- confirm against
        # PassAutoScanTest.
        yield config, ['elementwise_add'], (1e-5, 1e-5)

    def test(self):
        fuse_passes = [
            'elt_act_mkldnn_fuse_pass',
            'operator_scale_onednn_fuse_pass',
        ]
        self.run_and_statis(quant=False, passes=fuse_passes)
# Run the test case when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册