Unverified commit 84beef80, authored by H Hulek, committed by GitHub

Rewrite conv activation mkldnn fuse pass tester (#49278)

* Done

* Deleted old python test, fixed new python test, changed names in parallel_UT

* Revert parallel UT changes

* Revert parallel UT changes v2

* Review fixes and simplification of conv output shape calculation, disabled sqrt from conv_act_fuse_pass

* Delete sqrt from the possible activations in the conv_concat_relu test

* review refactor

* merge main

* delete sqrt from list of compatible activations

* Test with no outdated inputs
Parent 4632ca13
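For context, the pass under test rewrites a conv2d op followed by a supported activation into a single fused convolution carrying a fuse_activation attribute. A minimal conceptual sketch in plain Python (toy op dicts; all names are illustrative only, not Paddle's IR or API):

# Conceptual sketch of conv+activation fusion on a toy op list
# (hypothetical data structures; not Paddle's IR or API).
def fuse_conv_activation(ops, supported_acts=("relu", "relu6", "swish")):
    fused, i = [], 0
    while i < len(ops):
        op = ops[i]
        nxt = ops[i + 1] if i + 1 < len(ops) else None
        if (op["type"] == "conv2d" and nxt is not None
                and nxt["type"] in supported_acts
                and nxt["input"] == op["output"]):
            # Fold both ops into one; the intermediate tensor disappears,
            # which is why the C++ tester below expects 2 fewer graph nodes.
            fused.append({**op, "type": "fused_conv2d",
                          "fuse_activation": nxt["type"],
                          "output": nxt["output"]})
            i += 2
        else:
            fused.append(op)
            i += 1
    return fused

ops = [{"type": "conv2d", "input": "x", "output": "conv_out"},
       {"type": "relu", "input": "conv_out", "output": "y"}]
assert fuse_conv_activation(ops)[0]["fuse_activation"] == "relu"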
@@ -398,10 +398,6 @@ if(WITH_MKLDNN)
test_depthwise_conv_mkldnn_pass
SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc
DEPS depthwise_conv_mkldnn_pass)
cc_test(
test_conv_activation_mkldnn_fuse_pass
SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
DEPS conv_activation_mkldnn_fuse_pass)
cc_test_old(
test_int8_scale_calculation_mkldnn_pass SRCS
mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc DEPS
......
@@ -26,6 +26,8 @@ using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = GetSupportedActivations();
act_types.erase(std::remove(act_types.begin(), act_types.end(), "sqrt"),
act_types.end());
std::vector<std::string> conv_types = {"fused_conv2d", "conv2d"};
for (auto& act_type : act_types) {
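The two added lines above use C++'s erase-remove idiom to drop "sqrt" from the vector returned by GetSupportedActivations(). The same filtering in Python terms (illustration only; the activation names are a subset taken from this pass):

# Illustration only: Python equivalent of the C++ erase-remove above.
act_types = ["abs", "relu", "relu6", "sigmoid", "sqrt", "swish", "tanh"]
act_types = [act for act in act_types if act != "sqrt"]
assert "sqrt" not in act_types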
@@ -363,14 +365,6 @@ ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() {
.IsTensor()
.End();
AddOpCompat(OpCompat("sqrt"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("abs"))
.AddInput("X")
.IsTensor()
@@ -402,6 +396,5 @@ REGISTER_PASS_CAPABILITY(conv_activation_mkldnn_fuse_pass)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog,
const std::string& type,
const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
bool is_activation = false,
bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("groups", 1);
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("data_format", std::string("NCHW"));
op->SetAttr("strides", std::vector<int>({1, 1}));
op->SetAttr("dilations", std::vector<int>({1, 1}));
op->SetAttr("paddings", std::vector<int>({0, 0}));
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
op->SetOutput("Output", outputs);
} else if (is_activation) {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs);
if (type == "leaky_relu") {
op->SetAttr("alpha", 0.02f);
} else if (type == "relu6") {
op->SetAttr("threshold", 6.0f);
} else if (type == "mish") {
op->SetAttr("threshold", 20.0f);
} else if (type == "swish") {
op->SetAttr("beta", 1.0f);
}
op->SetOutput("Out", outputs);
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
// b->OP1->c
// (c, weights, bias)->conv->f
// (f)->activation->g
ProgramDesc BuildProgramDesc(std::string activation) {
ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a",
"b",
"c",
"weights",
"bias",
"f",
"g",
"h",
"weights2",
"bias2",
"k",
"l",
"m"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias" || v == "weights2" || v == "bias2") {
var->SetPersistable(true);
}
}
SetOp(&prog,
"OP0",
"op0",
std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog,
"OP1",
"op1",
std::vector<std::string>({"b"}),
std::vector<std::string>({"c"}));
// conv+activation, both with MKL-DNN
SetOp(&prog,
"conv2d",
"conv1",
std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}),
false,
true);
SetOp(&prog,
activation,
"activation1",
std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}),
true,
true);
SetOp(&prog,
"OP3",
"op3",
std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}));
// conv+activation, only one with MKL-DNN
SetOp(&prog,
"conv2d",
"conv2",
std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}),
false,
true);
SetOp(&prog,
"activation",
"activation2",
std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}),
true,
false);
SetOp(&prog,
"OP4",
"op4",
std::vector<std::string>({"l"}),
std::vector<std::string>({"m"}));
return prog;
}
void MainTest(std::string activation) {
auto prog = BuildProgramDesc(activation);
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("conv_activation_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
// The fuse removes 3 nodes (the conv op, the activation op, and the
// intermediate conv output var) and adds 1 fused op node, so the graph
// shrinks by exactly 2 nodes.
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
// Assert conv_activation op in newly generated graph
int conv_activation_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && (node->Op()->Type() == "conv2d" ||
node->Op()->Type() == "fused_conv2d")) {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
auto op_name = PADDLE_GET_CONST(std::string, op->GetAttr("name"));
if (op->GetAttrIfExists<std::string>("fuse_activation") == activation) {
++conv_activation_count;
}
// check if only "conv1" convolution is fused
if (op_name == "conv1") {
ASSERT_TRUE(op->HasAttr("fuse_activation"));
} else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_activation"));
}
}
}
EXPECT_EQ(conv_activation_count, 1);
}
TEST(ConvActivationFusePass, conv_relu_fuse_pass) { MainTest("relu"); }
TEST(ConvActivationFusePass, conv_leaky_relu_fuse_pass) {
MainTest("leaky_relu");
}
TEST(ConvActivationFusePass, conv_relu6_fuse_pass) { MainTest("relu6"); }
TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) {
MainTest("hard_swish");
}
TEST(ConvActivationFusePass, conv_mish_fuse_pass) { MainTest("mish"); }
TEST(ConvActivationFusePass, conv_hard_sigmoid_fuse_pass) {
MainTest("hard_sigmoid");
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_activation_mkldnn_fuse_pass);
@@ -206,7 +206,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
PROPERTIES TIMEOUT 60)
set_tests_properties(test_adaptive_pool2d_convert_global_pass_autoscan
PROPERTIES TIMEOUT 100)
set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_act_onednn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_elementwise_add2_act_fuse_pass
PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_elementwise_add_act_fuse_pass
@@ -260,7 +260,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_mkldnn_mish_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_conv3d_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_act_onednn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_transpose_eltwiseadd_bn_fuse_pass
PROPERTIES TIMEOUT 250)
set_tests_properties(test_onednn_matmul_transpose_reshape_fuse_pass
......
@@ -19,67 +19,45 @@ from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestConvActMkldnnFusePass(PassAutoScanTest):
r"""
x_var f_var(persistable)
\ /
conv2d
|
conv2d_var
|
act
|
act_var
"""
class TestConvActOneDNNFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
# MKLDNN
config = self.create_inference_config(use_gpu=False)
config.enable_mkldnn()
yield config, ["conv2d"], (1e-4, 1e-5)
yield config, ['fused_conv2d'], (1e-4, 1e-5)
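        # The yield above follows the PassAutoScanTest harness convention of
        # (predictor config, op types expected after the pass, (atol, rtol)):
        # once the fuse runs, the graph should contain fused_conv2d instead
        # of the separate conv2d + activation pair.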
def is_program_valid(self, prog_config):
paddings = prog_config.ops[0].attrs["paddings"]
strides = prog_config.ops[0].attrs["strides"]
groups = prog_config.ops[0].attrs["groups"]
padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"]
dilations = prog_config.ops[0].attrs["dilations"]
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape
if padding_algorithm == "VALID":
paddings = prog_config.ops[0].attrs['paddings']
groups = prog_config.ops[0].attrs['groups']
padding_algorithm = prog_config.ops[0].attrs['padding_algorithm']
dilations = prog_config.ops[0].attrs['dilations']
data_format = prog_config.ops[0].attrs['data_format']
filter_shape = prog_config.weights['filter'].shape
input_shape = prog_config.inputs['input_x'].shape
height = input_shape[data_format.index('H')]
width = input_shape[data_format.index('W')]
if padding_algorithm == 'VALID':
if (
(input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1))
/ strides[0]
+ 1
) <= 1 or (
(input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1))
/ strides[1]
+ 1
) <= 1:
height - (dilations[0] * (filter_shape[2] - 1) + 1) <= 0
or width - (dilations[1] * (filter_shape[3] - 1) + 1) <= 0
):
return False
if padding_algorithm == "EXPLICIT":
if padding_algorithm == 'EXPLICIT':
if (
(
input_shape[2]
+ paddings[0]
+ paddings[1]
- (dilations[0] * (filter_shape[2] - 1) + 1)
)
/ strides[0]
+ 1
) <= 1 or (
(
input_shape[3]
+ paddings[2]
+ paddings[3]
- (dilations[1] * (filter_shape[3] - 1) + 1)
)
/ strides[1]
+ 1
) <= 1:
height
+ paddings[0]
+ paddings[1]
- (dilations[0] * (filter_shape[2] - 1) + 1)
<= 0
or width
+ paddings[2]
+ paddings[3]
- (dilations[1] * (filter_shape[3] - 1) + 1)
<= 0
):
return False
if data_format == "NCHW":
if data_format == 'NCHW':
if input_shape[1] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
@@ -89,6 +67,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
return False
if filter_shape[0] % groups != 0:
return False
return True
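The VALID and EXPLICIT branches above encode the standard convolution output-size formula. A small helper makes the arithmetic explicit (a sketch for illustration; conv_out_size is not part of the test file):

# Standard conv2d output-size arithmetic behind the validity checks above
# (illustrative helper only; not part of the test file).
def conv_out_size(in_size, kernel, stride, dilation, pad_begin=0, pad_end=0):
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_begin + pad_end - effective_kernel) // stride + 1

# The test rejects any config where the (padded) input is not strictly
# larger than the effective kernel, which guarantees a positive output size:
assert conv_out_size(in_size=8, kernel=3, stride=1, dilation=2) == 4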
def sample_program_config(self, draw):
@@ -101,7 +80,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
x_shape[1] = draw(st.integers(min_value=5, max_value=10))
# 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
data_format = draw(st.sampled_from(['NCHW', 'NHWC']))
# 3. Generate legal shape of input:Y of conv2d
f_shape = draw(
@@ -109,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
st.integers(min_value=1, max_value=5), min_size=4, max_size=4
)
)
if data_format == "NCHW":
if data_format == 'NCHW':
f_shape[1] = x_shape[1]
else:
f_shape[1] = x_shape[3]
@@ -122,7 +101,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
)
# 5. Generate legal attr:padding_algorithm of conv2d
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
padding_algorithm = draw(st.sampled_from(['EXPLICIT', 'SAME', 'VALID']))
# 6. Generate legal attr:padding of conv2d
padding = draw(
@@ -141,53 +120,88 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
)
)
# 9. Generate legal input:ResidualData of conv2d
res_shape = []
if draw(st.booleans()):
res_shape = draw(
st.lists(
st.integers(min_value=1, max_value=100),
min_size=4,
max_size=4,
)
)
# 10. Generate legal shape of input:bias of conv2d
conv_bias_shape = []
inputs = dict()
weights = dict()
use_mkldnn = None
if draw(st.booleans()):
conv_bias_shape = [f_shape[0]]
inputs = {
"Input": ["input_x"],
"Filter": ["filter"],
"ResidualData": ["residualdata"],
"Bias": ["conv_bias"],
'Input': ['input_x'],
'Filter': ['filter'],
}
weights = {
"filter": TensorConfig(shape=f_shape),
"conv_bias": TensorConfig(shape=conv_bias_shape),
'filter': TensorConfig(shape=f_shape),
}
use_mkldnn = True
else:
inputs = {
"Input": ["input_x"],
"Filter": ["filter"],
"ResidualData": ["residualdata"],
'Input': ['input_x'],
'Filter': ['filter'],
}
weights = {"filter": TensorConfig(shape=f_shape)}
use_mkldnn = False
weights = {'filter': TensorConfig(shape=f_shape)}
# 11. Generate legal act type of conv2d
act_type = draw(
st.sampled_from(["relu", "leaky_relu", "relu6", "swish"])
st.sampled_from(
[
'abs',
'clip',
'gelu',
'hard_sigmoid',
'hard_swish',
'leaky_relu',
'mish',
'relu',
'relu6',
'sigmoid',
'swish',
'tanh',
]
)
)
# 12. Generate legal attr of act
act_op = None
self.passes = ['conv_activation_mkldnn_fuse_pass']
if act_type == 'relu6':
act_op = OpConfig(
'relu6',
inputs={'X': ['conv2d_out']},
outputs={'Out': ['relu_out']},
threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
)
elif act_type == 'leaky_relu':
act_op = OpConfig(
'leaky_relu',
inputs={'X': ['conv2d_out']},
outputs={'Out': ['relu_out']},
alpha=draw(st.floats(min_value=0.1, max_value=1.0)),
)
elif act_type == 'swish':
act_op = OpConfig(
'swish',
inputs={'X': ['conv2d_out']},
outputs={'Out': ['swish_out']},
beta=draw(st.floats(min_value=0.1, max_value=1.0)),
)
elif act_type == 'clip':
act_op = OpConfig(
'clip',
inputs={'X': ['conv2d_out']},
outputs={'Out': ['clip_out']},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)),
)
else:
act_op = OpConfig(
act_type,
inputs={'X': ['conv2d_out']},
outputs={'Out': ['activation_output']},
)
# 13. Create conv2d op
conv2d_op = OpConfig(
"conv2d",
'conv2d',
inputs=inputs,
outputs={"Output": ["conv2d_out"]},
outputs={'Output': ['conv2d_out']},
strides=strides,
padding_algorithm=padding_algorithm,
paddings=padding,
@@ -197,50 +211,15 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
use_mkldnn=True,
)
# 11. Generate legal attr of act
act_op = None
self.passes = ["conv_activation_mkldnn_fuse_pass"]
if act_type == "relu6":
threshold = draw(st.floats(min_value=1.0, max_value=10.0))
act_op = OpConfig(
"relu6",
inputs={"X": ["conv2d_out"]},
outputs={"Out": ["relu_out"]},
threshold=threshold,
)
elif act_type == "leaky_relu":
alpha = draw(st.floats(min_value=0.1, max_value=1.0))
act_op = OpConfig(
"leaky_relu",
inputs={"X": ["conv2d_out"]},
outputs={"Out": ["relu_out"]},
alpha=alpha,
)
elif act_type == "relu":
act_op = OpConfig(
"relu",
inputs={"X": ["conv2d_out"]},
outputs={"Out": ["relu_out"]},
)
elif act_type == "swish":
beta = draw(st.floats(min_value=0.1, max_value=1.0))
act_op = OpConfig(
"swish",
inputs={"X": ["conv2d_out"]},
outputs={"Out": ["swish_out"]},
beta=beta,
)
ops = [conv2d_op, act_op]
program_config = ProgramConfig(
ops=ops,
weights=weights,
inputs={
"input_x": TensorConfig(shape=x_shape),
"residualdata": TensorConfig(shape=res_shape),
'input_x': TensorConfig(shape=x_shape),
},
outputs=ops[-1].outputs["Out"],
outputs=ops[-1].outputs['Out'],
)
return program_config
@@ -248,5 +227,5 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
self.run_and_statis(quant=False, max_examples=300, passes=self.passes)
if __name__ == "__main__":
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle
import paddle.fluid as fluid
from paddle.fluid.core import PassVersionChecker
class ConvActivationMkldnnFusePassTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 100, 100], dtype="float32"
)
conv_out = paddle.static.nn.conv2d(
data,
num_filters=self.conv_num_filters,
filter_size=self.conv_filter_size,
bias_attr=self.conv_bias_attr,
act=self.act,
)
self.feeds = {
"data": np.random.random((1, 3, 100, 100)).astype("float32")
}
self.fetch_list = [conv_out]
self.enable_mkldnn = True
self.pass_name = 'conv_activation_mkldnn_fuse_pass'
def set_params(self):
self.conv_num_filters = 3
self.conv_filter_size = 3
self.conv_bias_attr = False
self.act = "relu"
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
def test_pass_compatible(self):
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class ConvActivationMkldnnFusePassTest_1(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 5
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "relu"
class ConvActivationMkldnnFusePassTest_2(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 3
self.conv_filter_size = 3
self.conv_bias_attr = False
self.act = "leaky_relu"
class ConvActivationMkldnnFusePassTest_3(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 5
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "leaky_relu"
class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 3
self.conv_filter_size = 3
self.conv_bias_attr = False
self.act = "relu6"
class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 5
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "hard_swish"
class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 5
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "mish"
class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 5
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "hard_sigmoid"
if __name__ == "__main__":
unittest.main()
@@ -37,7 +37,6 @@ class TestOneDNNConvConcatActivationFusePass(PassAutoScanTest):
'gelu',
'swish',
'mish',
'sqrt',
'hard_swish',
'sigmoid',
'abs',
......