Unverified commit 0420d514, authored by Sławomir Siwek, committed by GitHub

Matmuls with activation and elementwise_add fuses (#44655)

* Add unit tests

* matmul_v2 + activation

* matmuls + elementwise_add

* matmul_v2 postops

* transform matmul to v2

* opcompat

* fix fusing matmul with multiple outputs

* add shape constraints

* remove unused vars

* change pass order

* - Unit tests to be debugged

- fix

- refactor

- diagnostic

- more diagnostic

- fix

- Fix number two

- fix

- fix

- fix

- alpha added

- more fixes

- compilation fix

- removed diagnostic code

- cosmetic fixes

* lint

* add alpha constraint

* merge matmul refactor

* trigger CI

* - fix

* - another fix

* code style

* add support for matmul+elementwise_add+activation

* code style

* fix bfloat16 bugs

* change append_binary to append_sum
Co-authored-by: Jacek Czaja <jacek.czaja@intel.com>
Parent commit: 5e74ba33
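In effect, the two passes collapse a matmul -> elementwise_add -> activation chain into a single oneDNN matmul that carries the addend as a ResidualData input (realized as a binary_add post-op) and the activation as a further post-op. A minimal numpy sketch of the fused semantics, illustrative only (the helper name and the tanh choice are ours; alpha is matmul's output scale):

import numpy as np

def fused_matmul(x, y, residual, alpha=1.0, act=np.tanh):
    # one fused op: activation(alpha * (x @ y) + residual)
    return act(alpha * (x @ y) + residual)

x = np.random.rand(2, 3)
y = np.random.rand(3, 4)
residual = np.random.rand(2, 4)
reference = np.tanh(x @ y + residual)  # the unfused three-op chain
assert np.allclose(fused_matmul(x, y, residual), reference)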
@@ -212,6 +212,7 @@ if(WITH_MKLDNN)
pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn)
pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
......
@@ -2018,6 +2018,33 @@ PDNode *patterns::ElementwiseOp::operator()(
return out_var;
}
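// Matches a matmul/matmul_v2 op whose sole output feeds an elementwise_add,
// entering it as input X (as_x == true) or as input Y (as_x == false); the
// remaining elementwise_add input is the addend to be fused.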
PDNode *patterns::MatmulElementwiseAdd::operator()(
const std::string &matmul_type, bool as_x) {
auto matmul_op =
pattern->NewNode(matmul_op_repr())->assert_is_op(matmul_type);
auto matmul_out =
pattern->NewNode(matmul_out_repr())
->AsIntermediate()
->assert_is_op_output(matmul_type, "Out")
->assert_is_only_output_of_op(matmul_type)
->assert_is_op_input("elementwise_add", as_x ? "X" : "Y");
auto elementwise_addend =
pattern->NewNode(elementwise_addend_repr())
->AsInput()
->assert_is_op_input("elementwise_add", as_x ? "Y" : "X");
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_out =
pattern->NewNode(elementwise_add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
matmul_op->LinksTo({matmul_out});
elementwise_add_op->LinksFrom({matmul_out, elementwise_addend})
.LinksTo({elementwise_add_out});
return elementwise_add_out;
}
PDNode *patterns::ResidualElementwise::operator()(
PDNode *op_var,
PDNode *residual_var,
......
@@ -1038,6 +1038,21 @@ struct ElementwiseOp : public PatternBase {
PATTERN_DECL_NODE(elementwise_out);
};
struct MatmulElementwiseAdd : public PatternBase {
MatmulElementwiseAdd(PDPattern* pattern,
const std::string& name_scope,
const std::string& matmul_type,
bool as_x)
: PatternBase(pattern, name_scope, "matmul_elementwise_add") {}
PDNode* operator()(const std::string& matmul_type, bool as_x);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
PATTERN_DECL_NODE(elementwise_addend);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_out);
};
// Residual Elementwise ops
// This pattern allows operator output to be X or Y
// and residual data Y or X, based on as_x flag
......
@@ -199,8 +199,11 @@ class DeQuantizer final : public Quanter {
  bool IsNotPermittedName(const std::string& output_name) const override {
    std::unordered_map<std::string, std::vector<std::string>> block_list{
        {"layer_norm",
-        {"Mean", "Variance"}},     // not used in inference in MKLDNN
-       {"fc", {"ResidualData"}}};  // artifical output, already dequantized
+        {"Mean", "Variance"}},  // not used in inference in MKLDNN
+       {"fc", {"ResidualData"}},  // artificial output, already dequantized
+       {"matmul", {"ResidualData"}},  // artificial output, already dequantized
+       {"matmul_v2",
+        {"ResidualData"}}};  // artificial output, already dequantized
std::vector<std::string> blocked_outputs{"XShape"}; // blocklist for any op
auto op_name = op->Name();
......
@@ -26,7 +26,7 @@ using string::PrettyLogDetail;
void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
  auto act_types = paddle::platform::GetSupportedActivations();
-  std::vector<std::string> matmul_types = {"matmul"};
+  auto matmul_types = {"matmul", "matmul_v2"};
  for (const auto& matmul_type : matmul_types)
    for (auto& act_type : act_types) {
@@ -88,8 +88,9 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
  gpd(graph, handler);
  AddStatis(found_matmul_activation_count);
  if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
-    PrettyLogDetail("--- fused %d matmul with %s activation",
+    PrettyLogDetail("--- fused %d %s with %s activation",
                    found_matmul_activation_count,
+                    matmul_type,
                    act_type);
  }
}
@@ -102,6 +103,11 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
.AddInput("Y")
.IsTensor()
.End()
.AddInput(
"ResidualData") // Extra tensor used in matmul+elementwise_add fuse
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
@@ -115,6 +121,28 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddInput(
"ResidualData") // Extra tensor used in matmul+elementwise_add fuse
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("abs"))
.AddInput("X")
.IsTensor()
@@ -267,6 +295,7 @@ REGISTER_PASS_CAPABILITY(matmul_activation_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("abs", 0)
.LE("clip", 1)
.EQ("gelu", 0)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void MatmulElementwiseAddMKLDNNFusePass::ApplyImpl(Graph* graph) const {
auto matmul_types = {"matmul", "matmul_v2"};
auto matmul_as_x = {true, false};
for (const auto& matmul_type : matmul_types)
for (const auto& as_x : matmul_as_x) {
FuseMatmulElementwiseAdd(graph, matmul_type, as_x);
}
}
void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
Graph* graph, const std::string& matmul_type, bool matmul_as_x) const {
const std::string fusion_mode = matmul_as_x ? "x" : "y";
const auto name_scope = matmul_type + "_elementwise_add_as_" + fusion_mode;
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::MatmulElementwiseAdd matmul_pattern(
pattern, name_scope, matmul_type, matmul_as_x);
matmul_pattern(matmul_type, matmul_as_x);
int found_matmul_elementwise_add_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(matmul, matmul_op, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise_add, elementwise_add_op, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise_addend, elementwise_addend, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise_add_out, elementwise_add_out, matmul_pattern);
if (FindFuseOption(*matmul, *elementwise_add) != FUSE_MKLDNN) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "op compat for matmul_elementwise_add_mkldnn_fuse_pass failed.";
return;
}
    if (matmul->Op()->Inputs().count("ResidualData")) {  // ResidualData is an input, not an attribute
      LOG(WARNING) << "matmul_elementwise_add can only be fused once";
return;
}
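    // Perform the fuse: feed the addend into matmul's ResidualData input and
    // let matmul write directly to the former elementwise_add output.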
matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()});
matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()});
GraphSafeRemoveNodes(g, {matmul_out, elementwise_add});
IR_NODE_LINK_TO(elementwise_addend, matmul);
IR_NODE_LINK_TO(matmul, elementwise_add_out);
found_matmul_elementwise_add_count++;
};
gpd(graph, handler);
AddStatis(found_matmul_elementwise_add_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
PrettyLogDetail("--- fused %d %s (as %s) with elementwise_add",
found_matmul_elementwise_add_count,
matmul_type,
fusion_mode);
}
}
MatmulElementwiseAddMKLDNNFusePass::MatmulElementwiseAddMKLDNNFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 0, 1})
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::MatmulElementwiseAddMKLDNNFusePass);
REGISTER_PASS_CAPABILITY(matmul_elementwise_add_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.LE("elementwise_add", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class MatmulElementwiseAddMKLDNNFusePass : public FusePassBase {
public:
MatmulElementwiseAddMKLDNNFusePass();
virtual ~MatmulElementwiseAddMKLDNNFusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
void FuseMatmulElementwiseAdd(Graph* graph,
const std::string& matmul_type,
bool matmul_as_x) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -305,6 +305,8 @@ void CpuPassStrategy::EnableMKLDNN() {
          "reshape_transpose_matmul_v2_mkldnn_fuse_pass",  //
          "matmul_transpose_reshape_fuse_pass",            //
          "matmul_v2_transpose_reshape_fuse_pass",         //
+         "matmul_elementwise_add_mkldnn_fuse_pass",       //
+         "matmul_activation_mkldnn_fuse_pass",            //
          // Disabled due to topology-dependent speed-up
          // "fc_mkldnn_pass",
          // "fc_act_mkldnn_fuse_pass",
@@ -313,7 +315,6 @@ void CpuPassStrategy::EnableMKLDNN() {
          "softplus_activation_mkldnn_fuse_pass",   //
          "shuffle_channel_mkldnn_detect_pass",     //
          "elt_act_mkldnn_fuse_pass",               //
-         "matmul_activation_mkldnn_fuse_pass",     //
          // TODO(intel): Please fix the bug on windows.
          // https://github.com/PaddlePaddle/Paddle/issues/29710
          // "mkldnn_inplace_pass",  // This pass should be activated after
......
@@ -446,7 +446,6 @@ class MatMulMKLDNNHandler
    if (scale_out != 1.0f) {
      matmul_attrs.set_output_scales(0, {scale_out});
    }
-   paddle::platform::AppendActivation(ctx, post_operations);
    matmul_attrs.set_post_ops(post_operations);
@@ -698,6 +697,13 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (ctx.HasInput("ResidualData")) {
auto *residual_data = ctx.Input<Tensor>("ResidualData");
const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
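    // oneDNN convention: the extra source of binary post-op number idx is
    // passed as DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1.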
matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*residual_data_memory_p});
}
auto &astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
......
@@ -958,6 +958,16 @@ class MatMulV2MKLDNNHandler
matmul_attrs.set_output_scales(0, {alpha});
}
if (ctx.HasInput("ResidualData")) {
auto* residual_data = ctx.Input<Tensor>("ResidualData");
auto residual_data_tz = phi::vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::abcd);
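      // The residual join becomes a binary_add post-op on the matmul
      // primitive; format_tag::abcd assumes a 4-D residual tensor.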
post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md);
}
AppendActivation(ctx, post_operations);
matmul_attrs.set_post_ops(post_operations);
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
import hypothesis.strategies as st
class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
axis = draw(st.sampled_from([-1, 0, 1]))
matmul_as_x = draw(st.booleans())
batch_size = draw(st.integers(min_value=2, max_value=4))
channel = draw(st.sampled_from([16, 32, 64]))
input_dim = draw(st.sampled_from([16, 32, 64]))
        activation_type = draw(
            st.sampled_from([
                'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt',
                'hard_swish', 'abs', 'relu6', 'clip', 'hard_sigmoid',
                'leaky_relu'
            ]))
def generate_input():
return np.random.random([batch_size, channel, input_dim,
input_dim]).astype(np.float32)
matmul_op = OpConfig(type='matmul',
inputs={
'X': ['matmul_x'],
'Y': ['matmul_y']
},
outputs={'Out': ['matmul_output']},
attrs={
'use_mkldnn': True,
})
if matmul_as_x:
inputs = {'X': ['matmul_output'], 'Y': ['elementwise_addend']}
else:
inputs = {'X': ['elementwise_addend'], 'Y': ['matmul_output']}
elt_add_op = OpConfig(type='elementwise_add',
inputs=inputs,
outputs={'Out': ['elementwise_add_output']},
attrs={
'axis': axis,
'use_mkldnn': True
})
if activation_type == "relu6":
activation_op = OpConfig(activation_type,
inputs={"X": ["elementwise_add_output"]},
outputs={"Out": ["activation_output"]},
threshold=draw(
st.floats(min_value=1.0,
max_value=10.0)))
elif activation_type == "leaky_relu":
activation_op = OpConfig(activation_type,
inputs={"X": ["elementwise_add_output"]},
outputs={"Out": ["activation_output"]},
alpha=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == "swish":
activation_op = OpConfig(activation_type,
inputs={"X": ["elementwise_add_output"]},
outputs={"Out": ["activation_output"]},
beta=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == "clip":
activation_op = OpConfig(
activation_type,
inputs={"X": ["elementwise_add_output"]},
outputs={"Out": ["activation_output"]},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)))
else:
activation_op = OpConfig(activation_type,
inputs={"X": ["elementwise_add_output"]},
outputs={"Out": ["activation_output"]})
model_net = [matmul_op, elt_add_op, activation_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
'matmul_x': TensorConfig(data_gen=partial(generate_input)),
'matmul_y': TensorConfig(data_gen=partial(generate_input)),
'elementwise_addend':
TensorConfig(data_gen=partial(generate_input))
},
outputs=['activation_output'])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True,
passes=[
'matmul_elementwise_add_mkldnn_fuse_pass',
'matmul_activation_mkldnn_fuse_pass'
])
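        # expected ops in the optimized graph, plus (atol, rtol) tolerances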
yield config, ['matmul'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
passes=[
'matmul_elementwise_add_mkldnn_fuse_pass',
'matmul_activation_mkldnn_fuse_pass'
])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
import hypothesis.strategies as st
class TestMatmulElementwiseAddMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
axis = draw(st.sampled_from([-1, 0, 1]))
matmul_as_x = draw(st.booleans())
batch_size = draw(st.integers(min_value=2, max_value=4))
channel = draw(st.sampled_from([16, 32, 64]))
input_dim = draw(st.sampled_from([16, 32, 64]))
def generate_input():
return np.random.random([batch_size, channel, input_dim,
input_dim]).astype(np.float32)
matmul_op = OpConfig(type='matmul',
inputs={
'X': ['matmul_x'],
'Y': ['matmul_y']
},
outputs={'Out': ['matmul_output']},
attrs={
'use_mkldnn': True,
})
if matmul_as_x:
inputs = {'X': ['matmul_output'], 'Y': ['elementwise_addend']}
else:
inputs = {'X': ['elementwise_addend'], 'Y': ['matmul_output']}
elt_add_op = OpConfig(type='elementwise_add',
inputs=inputs,
outputs={'Out': ['elementwise_add_output']},
attrs={
'axis': axis,
'use_mkldnn': True
})
model_net = [matmul_op, elt_add_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
'matmul_x': TensorConfig(data_gen=partial(generate_input)),
'matmul_y': TensorConfig(data_gen=partial(generate_input)),
'elementwise_addend':
TensorConfig(data_gen=partial(generate_input))
},
outputs=['elementwise_add_output'])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True, passes=['matmul_elementwise_add_mkldnn_fuse_pass'])
yield config, ['matmul'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
passes=['matmul_elementwise_add_mkldnn_fuse_pass'])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
import hypothesis.strategies as st
class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
transpose_X = draw(st.booleans())
transpose_Y = draw(st.booleans())
batch_size = draw(st.integers(min_value=2, max_value=4))
channel = draw(st.sampled_from([16, 32, 64]))
input_dim = draw(st.sampled_from([16, 32, 64]))
        activation_type = draw(
            st.sampled_from([
                'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt',
                'hard_swish', 'abs', 'relu6', 'clip', 'hard_sigmoid',
                'leaky_relu'
            ]))
        # st.booleans() alone is a Hypothesis strategy object (always truthy);
        # draw it once here so broadcasting is actually randomized
        broadcast_X = draw(st.booleans())

        def generate_input(type):
            channel_X = 1 if broadcast_X else channel
            channel_Y = channel if broadcast_X else 1
            batch_size_X = 1 if broadcast_X else batch_size
            batch_size_Y = batch_size if broadcast_X else 1
if transpose_X and transpose_Y:
shape_x = [batch_size_X, channel_X, input_dim, 32]
shape_y = [batch_size_Y, channel_Y, 64, input_dim]
elif transpose_X:
shape_x = [batch_size_X, channel_X, input_dim, 32]
shape_y = [batch_size_Y, channel_Y, input_dim, 64]
elif transpose_Y:
shape_x = [batch_size_X, channel_X, 32, input_dim]
shape_y = [batch_size_Y, channel_Y, 8, input_dim]
else:
shape_x = [batch_size_X, channel_X, 32, input_dim]
shape_y = [batch_size_Y, channel_Y, input_dim, 16]
if type == 'X':
return np.random.random(shape_x).astype(np.float32)
else:
return np.random.random(shape_y).astype(np.float32)
matmul_op = OpConfig(type='matmul_v2',
inputs={
'X': ['matmul_X'],
'Y': ['matmul_Y']
},
outputs={'Out': ['matmul_output']},
attrs={
'trans_x': transpose_X,
'trans_y': transpose_Y
})
if activation_type == 'relu6':
activation_op = OpConfig(activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
threshold=draw(
st.floats(min_value=1.0,
max_value=10.0)))
elif activation_type == 'leaky_relu':
activation_op = OpConfig(activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
alpha=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == 'swish':
activation_op = OpConfig(activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
beta=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == 'clip':
activation_op = OpConfig(
activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)))
else:
activation_op = OpConfig(activation_type,
inputs={'X': ['matmul_output']},
outputs={'Out': ['activation_output']})
model_net = [matmul_op, activation_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
'matmul_X': TensorConfig(data_gen=partial(generate_input, 'X')),
'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'Y'))
},
outputs=['activation_output'])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ['matmul_v2'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
max_examples=30,
passes=['matmul_activation_mkldnn_fuse_pass'])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
import hypothesis.strategies as st
class TestMatmulV2ElementwiseAddMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
axis = draw(st.sampled_from([-1, 0, 1]))
matmul_as_x = draw(st.booleans())
batch_size = draw(st.integers(min_value=2, max_value=4))
channel = draw(st.sampled_from([16, 32, 64]))
input_dim_shared = draw(st.sampled_from([16, 32, 64]))
input_dim_X = draw(st.sampled_from([16, 32, 64]))
input_dim_Y = draw(st.sampled_from([16, 32, 64]))
        # st.booleans() alone is a Hypothesis strategy object (always truthy);
        # draw it once here so broadcasting is actually randomized
        broadcast_X = draw(st.booleans())

        def generate_input(type):
            channel_X = 1 if broadcast_X else channel
            channel_Y = channel if broadcast_X else 1
            batch_size_X = 1 if broadcast_X else batch_size
            batch_size_Y = batch_size if broadcast_X else 1
shape_x = [batch_size_X, channel_X, input_dim_X, input_dim_shared]
shape_y = [batch_size_Y, channel_Y, input_dim_shared, input_dim_Y]
if type == 'X':
return np.random.random(shape_x).astype(np.float32)
elif type == 'Y':
return np.random.random(shape_y).astype(np.float32)
else:
shape_out = [batch_size, channel, input_dim_X, input_dim_Y]
return np.random.random(shape_out).astype(np.float32)
matmul_op = OpConfig(type='matmul_v2',
inputs={
'X': ['matmul_X'],
'Y': ['matmul_Y']
},
outputs={'Out': ['matmul_output']},
attrs={'use_mkldnn': True})
if matmul_as_x:
inputs = {'X': ['matmul_output'], 'Y': ['elementwise_addend']}
else:
inputs = {'X': ['elementwise_addend'], 'Y': ['matmul_output']}
elt_add_op = OpConfig(type='elementwise_add',
inputs=inputs,
outputs={'Out': ['elementwise_add_output']},
attrs={
'axis': axis,
'use_mkldnn': True
})
model_net = [matmul_op, elt_add_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
'matmul_X':
TensorConfig(data_gen=partial(generate_input, 'X')),
'matmul_Y':
TensorConfig(data_gen=partial(generate_input, 'Y')),
'elementwise_addend':
TensorConfig(data_gen=partial(generate_input, 'ElAdd'))
},
outputs=['elementwise_add_output'])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ['matmul_v2'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
max_examples=30,
passes=['matmul_elementwise_add_mkldnn_fuse_pass'])
if __name__ == '__main__':
unittest.main()