Unverified commit 96597a85, authored by yeliang2258, committed by GitHub

Add tests for PaddleInference Pass (#37676)

* add test for conv_elementwise_add2_act_fuse_pass and conv_elementwise_add_act_fuse_pass

* Add conv_eltwiseadd_bn_fuse_pass test and fix test_conv_elementwise_addX_act_fuse_pass

* add tests for conv_act_mkldnn_fuse_pass

* add test for conv_bias_mkldnn_fuse_pass

* update code

* add conv_act_mkldnn_fuse_pass for relu, relu6, swish, leaky_relu

* update test

* update

* fix bug

* update

* update pattern_detector

* fix test_conv_eltwiseadd_bn_fuse_pass

* add diff display notest;test=windows_ci_inference

* fix

* remove test_conv_act_mkldnn_fuse_pass.py

* fix
Parent e02537f9
@@ -173,7 +173,7 @@ if(NOT WIN32)
 endif()
 if (WITH_MKLDNN)
   cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
-  cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
+  cc_test(test_conv_bias_mkldnn_fuse_pass_cc SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
   cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
   cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
   cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util)
......
@@ -807,6 +807,7 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
   // Bias
   eltwise_y_in_var = pattern->NewNode(eltwise_y_in_repr())
                          ->assert_is_op_input("elementwise_add", "Y")
+                         ->assert_is_persistable_var()
                          ->AsInput();
   eltwise_out_var = pattern->NewNode(eltwise_out_repr())
                         ->AsIntermediate()
......
@@ -201,6 +201,9 @@ Conv2DSwishFusePass::Conv2DSwishFusePass() {
       .End()
       .AddOutput("Out")
       .IsTensor()
+      .End()
+      .AddAttr("beta")
+      .IsType<float>()
       .End();
 }
 Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
......
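For context, beta is the slope parameter of the swish activation that the matcher now records: swish(x) = x * sigmoid(beta * x). A minimal numpy sketch of the activation itself (illustrative only, not the pass code):

import numpy as np

def swish(x, beta=1.0):
    # swish(x) = x * sigmoid(beta * x); `beta` is the attr the pass now validates
    return x / (1.0 + np.exp(-beta * x))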
@@ -239,6 +239,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
     auto input_names = conv->Op()->InputNames();
     bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                     input_names.end();
     if (has_bias && conv->Op()->Input("Bias").size() > 0) {
       auto conv_bias_names = conv->Op()->Input("Bias");
       // add eltwise bias to existing conv bias
......
@@ -80,7 +80,9 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
   set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
   set_tests_properties(test_simplify_with_basic_ops_pass_autoscan PROPERTIES TIMEOUT 60)
   set_tests_properties(test_adaptive_pool2d_convert_global_pass_autoscan PROPERTIES TIMEOUT 60)
-  set_tests_properties(test_conv_eltwiseadd_affine_channel_fuse_pass PROPERTIES TIMEOUT 100)
+  set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_conv_elementwise_add2_act_fuse_pass PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_conv_elementwise_add_act_fuse_pass PROPERTIES TIMEOUT 120)
 endif()
 if (WITH_MKLDNN)
......
@@ -144,10 +144,12 @@ class AutoScanTest(unittest.TestCase):
                 baseline[key].shape == arr.shape,
                 "The output shapes are not equal, the baseline shape is " +
                 str(baseline[key].shape) + ', but got ' + str(arr.shape))
+            diff = abs(baseline[key] - arr)
             self.assertTrue(
                 np.allclose(
                     baseline[key], arr, atol=atol, rtol=rtol),
-                "Output has diff. ")
+                "Output has diff, Maximum absolute error: {}".format(
+                    np.amax(diff)))

     @abc.abstractmethod
     def run_test(self, quant=False):
......
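The change above makes a tolerance failure report the worst elementwise mismatch instead of a bare "Output has diff.". A standalone sketch of the same diagnostic (hypothetical helper, assuming numpy arrays):

import numpy as np

def assert_allclose_verbose(baseline, actual, atol=1e-4, rtol=1e-5):
    # np.allclose passes iff |actual - baseline| <= atol + rtol * |baseline| elementwise
    diff = np.abs(baseline - actual)
    assert np.allclose(baseline, actual, atol=atol, rtol=rtol), \
        "Output has diff, Maximum absolute error: {}".format(np.amax(diff))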
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestConvBiasMkldnnFusePass(PassAutoScanTest):
"""
x_var f_var(persistable)
\ /
conv2d
|
conv2d_var bias_var(persistable)
\ /
elementwise_add
|
elementwise_add_var
"""
def sample_predictor_configs(self, program_config):
# MKLDNN
config = self.create_inference_config(use_gpu=False)
config.enable_mkldnn()
yield config, ["conv2d"], (1e-4, 1e-5)
def is_program_valid(self, prog_config):
paddings = prog_config.ops[0].attrs["paddings"]
strides = prog_config.ops[0].attrs["strides"]
groups = prog_config.ops[0].attrs["groups"]
padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"]
dilations = prog_config.ops[0].attrs["dilations"]
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID":
if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
return False
if padding_algorithm == "EXPLICIT":
if ((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
return False
if data_format == "NCHW":
if input_shape[1] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
else:
if input_shape[3] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
return True
def sample_program_config(self, draw):
# 1. Generate shape of input:X of conv2d
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=1, max_value=10))
# 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
# 3. Generate legal shape of input:Y of conv2d
f_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=7), min_size=4, max_size=4))
if data_format == "NCHW":
f_shape[1] = x_shape[1]
else:
f_shape[1] = x_shape[3]
# 4. Generate legal attr:strides of conv2d
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
# 5. Generate legal attr:padding_algorithm of conv2d
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
# 6. Generate legal attr:padding of conv2d
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=4, max_size=4))
# 7. Generate legal attr:groups of conv2d
groups = draw(st.integers(min_value=1, max_value=3))
# 8. Generate legal attr:dilations of conv2d
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
# 9. Generate legal shape of input:bias of elementwise_add
bias_shape = [f_shape[0]]
        # 10. Generate legal attr:axis of elementwise_add
        if data_format == "NCHW":
            axis = 1
        else:
            axis = 3
# 11. Generate legal shape of input:bias of conv2d
conv_bias_shape = []
inputs = dict()
weights = dict()
use_mkldnn = None
if draw(st.booleans()):
conv_bias_shape = [f_shape[0]]
inputs = {
"Input": ["input_x"],
"Filter": ["filter"],
"Bias": ["conv_bias"],
}
weights = {
"filter": TensorConfig(shape=f_shape),
"bias": TensorConfig(shape=bias_shape),
"conv_bias": TensorConfig(shape=conv_bias_shape)
}
use_mkldnn = True
else:
inputs = {
"Input": ["input_x"],
"Filter": ["filter"],
}
weights = {
"filter": TensorConfig(shape=f_shape),
"bias": TensorConfig(shape=bias_shape)
}
use_mkldnn = False
conv2d_op = OpConfig(
"conv2d",
inputs=inputs,
outputs={"Output": ["conv2d_out"]},
strides=strides,
padding_algorithm=padding_algorithm,
paddings=padding,
groups=groups,
dilations=dilations,
data_format=data_format,
use_mkldnn=use_mkldnn)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["conv2d_out"],
"Y": ["bias"]},
outputs={"Out": ["add_out"]},
axis=axis)
ops = [conv2d_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights=weights,
inputs={"input_x": TensorConfig(shape=x_shape)},
outputs=ops[-1].outputs["Out"])
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=350,
passes=["conv_bias_mkldnn_fuse_pass"])
if __name__ == "__main__":
unittest.main()
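The shape arithmetic in `is_program_valid` above is the standard conv2d output-size formula. A condensed sketch of that check (hypothetical helper, mirroring the expressions in the test):

def conv2d_out_size(in_size, kernel, stride, dilation, pad_begin=0, pad_end=0):
    # Dilated kernel extent, then the usual floor-division formula.
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_begin + pad_end - effective_kernel) // stride + 1

# Example: input height 7, 3x3 filter, stride 2, dilation 1, no padding:
# (7 - 3) // 2 + 1 == 3, so the sampled program would be kept.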
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,44 +12,223 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import print_function
-import unittest
-import numpy as np
-from inference_pass_test import InferencePassTest
-import paddle.fluid as fluid
-import paddle.fluid.core as core
-from paddle.fluid.core import PassVersionChecker
-from paddle.fluid.core import AnalysisConfig
-"""Test for fusion of conv, elementwise_add and 2 act."""
-
-class ConvElementwiseAdd2ActFusePassTest(InferencePassTest):
-    def setUp(self):
-        with fluid.program_guard(self.main_program, self.startup_program):
-            data = fluid.data(
-                name="data", shape=[-1, 3, 100, 100], dtype="float32")
-            add_y2 = fluid.data(
-                name="add_y2", shape=[1, 3, 98, 98], dtype="float32")
-            conv_out = fluid.layers.conv2d(
-                input=data, num_filters=3, filter_size=3, bias_attr=None)
-            add1_out = fluid.layers.elementwise_add(
-                add_y2, conv_out, act="relu")
-        self.feeds = {
-            "data": np.random.random((1, 3, 100, 100)).astype("float32"),
-            "add_y2": np.random.random((1, 3, 98, 98)).astype("float32")
-        }
-        self.fetch_list = [add1_out]
-        self.enable_mkldnn = False
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            use_gpu = True
-            self.check_output_with_option(use_gpu)
-            self.assertTrue(
-                PassVersionChecker.IsCompatible(
-                    'conv_elementwise_add2_act_fuse_pass'))

from auto_scan_test import PassAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st

class TestConvElementwiseAdd2ActPass(PassAutoScanTest):
    """
    x_var    f_var(persistable)
       \       /
        conv2d
           |
     conv2d_var    y_var(persistable)
           \        /
        elementwise_add
               |
    x1_var   elementwise_add_out_var
        \       /
      elementwise_add
            |
           act
            |
         act_var
    """

    def sample_predictor_configs(self, program_config):
        # for gpu
        config = self.create_inference_config(use_gpu=True)
        yield config, ["conv2d_fusion"], (1e-4, 1e-5)

    def is_program_valid(self, prog_config):
paddings = prog_config.ops[0].attrs["paddings"]
strides = prog_config.ops[0].attrs["strides"]
groups = prog_config.ops[0].attrs["groups"]
padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"]
dilations = prog_config.ops[0].attrs["dilations"]
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID":
if int(((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1)) <= 0 or \
int(((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1)) <= 0:
return False
if padding_algorithm == "EXPLICIT":
if int(((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1)) <= 0 or \
int(((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1)) <= 0:
return False
if padding_algorithm == "SAME":
if int((input_shape[2] + strides[0] - 1) / strides[0]) <= 0 or int(
(input_shape[3] + strides[1] - 1) / strides[1]) <= 0:
return False
if data_format == "NCHW":
if input_shape[1] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
else:
if input_shape[3] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
return True
def sample_program_config(self, draw):
is_not_valid = True
program_config = None
while is_not_valid:
# 1. Generate shape of input:X of conv2d
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=100),
min_size=4,
max_size=4))
x_shape[1] = draw(st.integers(min_value=1, max_value=10))
# 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
# 3. Generate legal shape of input:Y of conv2d
f_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=7),
min_size=4,
max_size=4))
if data_format == "NCHW":
f_shape[1] = x_shape[1]
else:
f_shape[1] = x_shape[3]
# 4. Generate legal attr:strides of conv2d
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=2,
max_size=2))
# 5. Generate legal attr:padding_algorithm of conv2d
padding_algorithm = draw(
st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
# 6. Generate legal attr:padding of conv2d
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=4,
max_size=4))
# 7. Generate legal attr:groups of conv2d
groups = draw(st.integers(min_value=1, max_value=3))
# 8. Generate legal attr:dilations of conv2d
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=2,
max_size=2))
            # 9. Generate legal shape of input:X of the second elementwise_add
bias_2_dict = dict()
bias_2_dict[1] = [x_shape[0], f_shape[0], \
int(((x_shape[2] + padding[0] + padding[1] - (dilations[0] * (f_shape[2] - 1) + 1)) / strides[0] + 1)), \
int(((x_shape[3] + padding[2] + padding[3] - (dilations[1] * (f_shape[3] - 1) + 1)) / strides[1] + 1))]
bias_2_dict[2] = [x_shape[0], f_shape[0], \
int((x_shape[2] + strides[0] - 1) / strides[0]), \
int((x_shape[3] + strides[1] - 1) / strides[1])]
bias_2_dict[3] = [x_shape[0], f_shape[0], \
int(((x_shape[2] - (dilations[0] * (f_shape[2] - 1) + 1)) / strides[0] + 1)), \
int(((x_shape[3] - (dilations[1] * (f_shape[3] - 1) + 1)) / strides[1] + 1))]
bias_index = 1
if padding_algorithm == "SAME":
bias_index = 2
if padding_algorithm == "VALID":
bias_index = 3
bias_2_shape = bias_2_dict[bias_index]
if np.sum(np.array(bias_2_shape) <= 0) == 0:
is_not_valid = False
else:
continue
# 10. Generate legal shape of input:bias of elementwise_add
bias_shape = [f_shape[0]]
# 11. Generate legal attr:axis of elementwise_add_1
axis_1 = 1
# 12. Generate legal attr:axis of elementwise_add_2
axis_2 = -1
conv2d_op = OpConfig(
"conv2d",
inputs={"Input": ["input_x"],
"Filter": ["filter"]},
outputs={"Output": ["conv2d_out"]},
strides=strides,
padding_algorithm=padding_algorithm,
paddings=padding,
groups=groups,
dilations=dilations,
data_format=data_format)
add_1_op = OpConfig(
"elementwise_add",
inputs={"X": ["conv2d_out"],
"Y": ["bias_1"]},
outputs={"Out": ["add_1_out"]},
axis=axis_1)
add_2_op = OpConfig(
"elementwise_add",
inputs={"X": ["bias_2"],
"Y": ["add_1_out"]},
outputs={"Out": ["add_out"]},
axis=axis_2)
relu_op = OpConfig(
"relu",
inputs={"X": ["add_out"]},
outputs={"Out": ["relu_out"]})
ops = [conv2d_op, add_1_op, add_2_op, relu_op]
program_config = ProgramConfig(
ops=ops,
weights={
"filter": TensorConfig(shape=f_shape),
"bias_1": TensorConfig(shape=bias_shape),
},
inputs={
"input_x": TensorConfig(shape=x_shape),
"bias_2": TensorConfig(shape=bias_2_shape)
},
outputs=ops[-1].outputs["Out"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=300,
passes=["conv_elementwise_add2_act_fuse_pass"])
if __name__ == "__main__":
......
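Note that `bias_2_dict` above enumerates the conv2d output shape for each padding mode so that `bias_2`, the input of the second elementwise_add, is shape-compatible with `add_1_out`. Collapsed into one hypothetical helper:

def second_add_input_shape(x_shape, f_shape, strides, dilations, paddings,
                           padding_algorithm):
    # NCHW layout: batch comes from x, channels from the filter count.
    n, oc = x_shape[0], f_shape[0]
    if padding_algorithm == "SAME":
        h = (x_shape[2] + strides[0] - 1) // strides[0]
        w = (x_shape[3] + strides[1] - 1) // strides[1]
    else:
        # EXPLICIT uses the sampled paddings; VALID pads nothing.
        ph = paddings[0] + paddings[1] if padding_algorithm == "EXPLICIT" else 0
        pw = paddings[2] + paddings[3] if padding_algorithm == "EXPLICIT" else 0
        h = (x_shape[2] + ph - (dilations[0] * (f_shape[2] - 1) + 1)) // strides[0] + 1
        w = (x_shape[3] + pw - (dilations[1] * (f_shape[3] - 1) + 1)) // strides[1] + 1
    return [n, oc, h, w]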
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,46 +12,177 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import print_function
-import unittest
-import numpy as np
-from inference_pass_test import InferencePassTest
-import paddle.fluid as fluid
-import paddle.fluid.core as core
-from paddle.fluid.core import PassVersionChecker
-from paddle.fluid.core import AnalysisConfig
-"""Test for fusion of conv, elementwise_add and act."""
-
-class ConvElementwiseAddActFusePassTest(InferencePassTest):
-    def setUp(self):
-        with fluid.program_guard(self.main_program, self.startup_program):
-            data = fluid.data(
-                name="data", shape=[-1, 3, 100, 100], dtype="float32")
-            param_attr = fluid.ParamAttr(
-                initializer=fluid.initializer.Xavier(uniform=False),
-                learning_rate=0.001)
-            conv_out = fluid.layers.conv2d(
-                input=data,
-                num_filters=3,
-                filter_size=3,
-                bias_attr=param_attr,
-                act="relu")
-        self.feeds = {
-            "data": np.random.random((1, 3, 100, 100)).astype("float32")
-        }
-        self.fetch_list = [conv_out]
-        self.enable_mkldnn = False
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            use_gpu = True
-            self.check_output_with_option(use_gpu)
-            self.assertTrue(
-                PassVersionChecker.IsCompatible(
-                    'conv_elementwise_add_act_fuse_pass'))

from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st

class TestConvElementwiseAddActPass(PassAutoScanTest):
    """
    x_var    f_var(persistable)
       \       /
        conv2d
           |
     conv2d_var    y_var(persistable)
           \        /
        elementwise_add
               |
      elementwise_add_var
               |
              act
               |
            act_var
    """

    def sample_predictor_configs(self, program_config):
        # for gpu
        config = self.create_inference_config(use_gpu=True)
        yield config, ["conv2d_fusion"], (1e-4, 1e-5)

    def is_program_valid(self, prog_config):
        paddings = prog_config.ops[0].attrs["paddings"]
        strides = prog_config.ops[0].attrs["strides"]
        groups = prog_config.ops[0].attrs["groups"]
padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"]
dilations = prog_config.ops[0].attrs["dilations"]
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID":
if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
return False
if padding_algorithm == "EXPLICIT":
if ((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
return False
if data_format == "NCHW":
if input_shape[1] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
else:
if input_shape[3] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
return True
def sample_program_config(self, draw):
# 1. Generate shape of input:X of conv2d
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=1, max_value=10))
# 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
# 3. Generate legal shape of input:Y of conv2d
f_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=7), min_size=4, max_size=4))
if data_format == "NCHW":
f_shape[1] = x_shape[1]
else:
f_shape[1] = x_shape[3]
# 4. Generate legal attr:strides of conv2d
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
# 5. Generate legal attr:padding_algorithm of conv2d
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
# 6. Generate legal attr:padding of conv2d
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=4, max_size=4))
# 7. Generate legal attr:groups of conv2d
groups = draw(st.integers(min_value=1, max_value=3))
# 8. Generate legal attr:dilations of conv2d
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
# 9. Generate legal input:ResidualData of conv2d
res_shape = []
if draw(st.booleans()):
res_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=100),
min_size=4,
max_size=4))
# 10. Generate legal shape of input:bias of elementwise_add
bias_shape = [f_shape[0]]
# 11. Generate legal attr:axis of elementwise_add
axis = 1
conv2d_op = OpConfig(
"conv2d",
inputs={
"Input": ["input_x"],
"Filter": ["filter"],
"ResidualData": ["residualdata"]
},
outputs={"Output": ["conv2d_out"]},
strides=strides,
padding_algorithm=padding_algorithm,
paddings=padding,
groups=groups,
dilations=dilations,
data_format=data_format)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["conv2d_out"],
"Y": ["bias"]},
outputs={"Out": ["add_out"]},
axis=axis)
relu_op = OpConfig(
"relu", inputs={"X": ["add_out"]}, outputs={"Out": ["relu_out"]})
ops = [conv2d_op, add_op, relu_op]
program_config = ProgramConfig(
ops=ops,
weights={
"filter": TensorConfig(shape=f_shape),
"bias": TensorConfig(shape=bias_shape),
},
inputs={
"input_x": TensorConfig(shape=x_shape),
"residualdata": TensorConfig(shape=res_shape)
},
outputs=ops[-1].outputs["Out"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=400,
passes=["conv_elementwise_add_act_fuse_pass"])
if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestConvEltwiseaddBnFusePass(PassAutoScanTest):
"""
x_var f_var(persistable)
\ /
conv2d
|
conv2d_var bias_var(persistable)
\ /
elementwise_add
|
elementwise_add_var Scale(persistable) Bias(persistable) Mean(persistable) Variance(persistable)
|
batch_norm
|
    Y    MeanOut    VarianceOut    SavedMean    SavedVariance
"""
def sample_predictor_configs(self, program_config):
# cpu
config = self.create_inference_config(use_gpu=False)
yield config, ["conv2d", "elementwise_add"], (1e-4, 1e-5)
# MKLDNN
config = self.create_inference_config(use_gpu=False)
config.enable_mkldnn()
yield config, ["conv2d", "elementwise_add"], (1e-4, 1e-5)
# for gpu
config = self.create_inference_config(use_gpu=True)
yield config, ["conv2d", "elementwise_add"], (1e-4, 1e-5)
def is_program_valid(self, prog_config):
paddings = prog_config.ops[0].attrs["paddings"]
strides = prog_config.ops[0].attrs["strides"]
groups = prog_config.ops[0].attrs["groups"]
padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"]
dilations = prog_config.ops[0].attrs["dilations"]
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID":
if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
return False
if padding_algorithm == "EXPLICIT":
if ((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
return False
if data_format == "NCHW":
if input_shape[1] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
else:
if input_shape[3] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
bn_scale = np.array(prog_config.weights["scale_in"].data)
bn_bias = np.array(prog_config.weights["bias_in"].data)
bn_mean = np.array(prog_config.weights["mean_in"].data)
bn_variance = np.array(prog_config.weights["variance_in"].data)
epsilon = np.array(prog_config.ops[-1].attrs["epsilon"])
bn_variance = bn_variance + epsilon
if np.isnan(bn_variance).any():
return False
bn_variance = np.sqrt(bn_variance)
if np.sum(bn_variance == 0.0) > 0:
return False
bn_variance = bn_scale / bn_variance
if np.isnan(bn_variance).any():
return False
return True
def sample_program_config(self, draw):
# 1. Generate shape of input:X of conv2d
x_shape = draw(
st.lists(
st.integers(
min_value=10, max_value=100),
min_size=4,
max_size=4))
x_shape[1] = draw(st.integers(min_value=1, max_value=10))
# 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
        # 3. Generate legal shape of input:Y of conv2d
f_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=7), min_size=4, max_size=4))
if data_format == "NCHW":
f_shape[1] = x_shape[1]
else:
f_shape[1] = x_shape[3]
        # 4. Generate legal attr:strides of conv2d
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
        # 5. Generate legal attr:padding_algorithm of conv2d
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
        # 6. Generate legal attr:padding of conv2d
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=4, max_size=4))
        # 7. Generate legal attr:groups of conv2d
groups = draw(st.integers(min_value=1, max_value=3))
        # 8. Generate legal attr:dilations of conv2d
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
# 9. Generate legal input:ResidualData of conv2d
res_shape = []
if draw(st.booleans()):
res_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=100),
min_size=4,
max_size=4))
# 10. Generate legal shape of input:bias of elementwise_add
bias_shape = [f_shape[0]]
# 11. Generate legal attr:axis of elementwise_add
axis = 1
# 12. Generate legal input:Scale of batch_norm
bn_scale_shape = [f_shape[0]]
# 13. Generate legal input:Bias of batch_norm
bn_bias_shape = [f_shape[0]]
# 14. Generate legal input:Mean of batch_norm
bn_mean_shape = [f_shape[0]]
# 15. Generate legal input:Variance of batch_norm
bn_variance_shape = [f_shape[0]]
# 16. Generate legal attr:epsilon of batch_norm
epsilon = draw(st.floats(min_value=0.00001, max_value=0.001))
def generate_batch_variance():
return (0.1 + (1.0 - 0.1) * np.random.random(bn_variance_shape)
).astype(np.float32)
conv2d_op = OpConfig(
"conv2d",
inputs={
"Input": ["input_x"],
"Filter": ["filter"],
"ResidualData": ["residualdata"]
},
outputs={"Output": ["conv2d_out"]},
strides=strides,
padding_algorithm=padding_algorithm,
paddings=padding,
groups=groups,
dilations=dilations,
data_format=data_format)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["conv2d_out"],
"Y": ["bias"]},
outputs={"Out": ["add_out"]},
axis=axis)
bn_op = OpConfig(
"batch_norm",
inputs={
"X": ["add_out"],
"Scale": ["scale_in"],
"Bias": ["bias_in"],
"Mean": ["mean_in"],
"Variance": ["variance_in"]
},
outputs={
"Y": ["y_out"],
"MeanOut": ["mean_in"],
"VarianceOut": ["variance_in"],
"SavedMean": ["SavedMean_out"],
"SavedVariance": ["SavedVariance_out"],
"ReserveSpace": ["ReserveSpace_out"]
},
epsilon=epsilon,
is_test=True,
trainable_statistics=False,
data_layout=data_format)
ops = [conv2d_op, add_op, bn_op]
        # 17. Randomly decide whether the elementwise_add bias is consumed
        # only by batch_norm or is also fetched as a program output
        if draw(st.booleans()):
            outputs = ops[-1].outputs["Y"]
        else:
            outputs = ops[-1].outputs["Y"] + ["bias"]
program_config = ProgramConfig(
ops=ops,
weights={
"filter": TensorConfig(shape=f_shape),
"bias": TensorConfig(shape=bias_shape),
"scale_in": TensorConfig(shape=bn_scale_shape),
"bias_in": TensorConfig(shape=bn_bias_shape),
"mean_in": TensorConfig(shape=bn_mean_shape),
"variance_in": TensorConfig(data_gen=generate_batch_variance),
},
inputs={
"input_x": TensorConfig(shape=x_shape),
"residualdata": TensorConfig(shape=res_shape)
},
outputs=outputs)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=300,
passes=["conv_eltwiseadd_bn_fuse_pass"])
if __name__ == "__main__":
unittest.main()
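The variance guards in `is_program_valid` above mirror the algebra that conv_eltwiseadd_bn_fuse_pass performs when folding batch_norm into the preceding conv and elementwise_add: the division by sqrt(variance + epsilon) must stay finite and nonzero. A minimal numpy sketch of the standard folding (not the pass's C++ implementation):

import numpy as np

def fold_bn_into_conv_bias(filter_w, eltwise_bias, scale, bias, mean,
                           variance, epsilon):
    # Per-output-channel rescale factor; the test rejects programs where
    # this would divide by zero or produce NaN.
    alpha = scale / np.sqrt(variance + epsilon)
    new_filter = filter_w * alpha.reshape(-1, 1, 1, 1)  # OIHW layout
    new_bias = (eltwise_bias - mean) * alpha + bias
    return new_filter, new_bias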
@@ -431,7 +431,7 @@ HIGH_PARALLEL_JOB_NEW = [
     'test_memory_usage',
     'test_sysconfig',
     'reader_test',
-    'test_conv_bias_mkldnn_fuse_pass',
+    'test_conv_bias_mkldnn_fuse_pass_cc',
     'math_function_test',
     'beam_search_decode_op_test',
     'save_quant2_model_resnet50',
@@ -1469,7 +1469,7 @@ CPU_PARALLEL_JOB = [
     'test_cpu_bfloat16_placement_pass',
     'test_cpu_bfloat16_pass',
     'test_conv_concat_relu_mkldnn_fuse_pass',
-    'test_conv_bias_mkldnn_fuse_pass',
+    'test_conv_bias_mkldnn_fuse_pass_cc',
     'test_conv_batch_norm_mkldnn_fuse_pass',
     'test_conv3d_transpose_layer',
     'test_conv3d_mkldnn_op',
......
File mode changed from 100644 to 100755