Unverified · Commit 6c4621f1 · authored by Jason · committed by GitHub

Refine autoscan pass (#37363)

Parent e87545ce
@@ -72,6 +72,7 @@ set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60)
 set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60)
 set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30)
 set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120)
+set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 120)
 if (WITH_MKLDNN)
     set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300)
......
@@ -31,7 +31,8 @@ from typing import Optional, List, Callable, Dict, Any, Set
 from program_config import TensorConfig, OpConfig, ProgramConfig, create_fake_model, create_quant_model
 import hypothesis
-from hypothesis import given, settings, seed, example, assume
+from hypothesis import given, settings, seed, reproduce_failure
+import hypothesis.strategies as st
 logging.basicConfig(level=logging.INFO, format="%(message)s")

@@ -78,6 +79,11 @@ class AutoScanTest(unittest.TestCase):
         abs_dir = os.path.abspath(os.path.dirname(__file__))
         self.cache_dir = os.path.join(abs_dir,
                                       str(self.__module__) + '_cache_dir')
+        self.available_passes_in_framework = set()
+        self.num_ran_programs = 0
+        self.num_invalid_programs = 0
+        self.num_skipped_tests = 0
+        self.num_predictor_kinds = 0

     @abc.abstractmethod
     def sample_program_configs(self):

@@ -99,9 +105,8 @@ class AutoScanTest(unittest.TestCase):
                       note: str):
         self.skip_cases.append((teller, reason, note))

-    @abc.abstractmethod
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        raise NotImplementedError
+        return True

     def run_test_config(self, model, params, prog_config, pred_config,
                         feed_data) -> Dict[str, np.ndarray]:

@@ -110,6 +115,8 @@ class AutoScanTest(unittest.TestCase):
         '''
         pred_config.set_model_buffer(model, len(model), params, len(params))
         predictor = paddle_infer.create_predictor(pred_config)
+        self.available_passes_in_framework = self.available_passes_in_framework | set(
+            pred_config.pass_builder().all_passes())
         for name, _ in prog_config.inputs.items():
             input_tensor = predictor.get_input_handle(name)
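The accumulated pass set can be inspected outside the test class as well; a minimal sketch of the `pass_builder()` query used above, assuming a Paddle build that ships the inference API:

import paddle.inference as paddle_infer

# Build a bare CPU config and collect the names of all passes it would run,
# mirroring how run_test_config populates available_passes_in_framework.
config = paddle_infer.Config()
config.disable_gpu()
available_passes = set(config.pass_builder().all_passes())
print('fc_fuse_pass' in available_passes)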
@@ -277,39 +284,118 @@ class PassAutoScanTest(AutoScanTest):
     def check_op_version(self):
         status = True
         for pass_name in self.passes:
+            if pass_name not in self.available_passes_in_framework:
+                continue
             if not PassVersionChecker.IsCompatible(pass_name):
                 self.fail_log('{} version check failed.'.format(pass_name))
                 status = False
         return status

-    def assert_op_size(self, fusion_before_num, fusion_after_num, origin_model):
+    def add_skip_pass_case(self):
+        return
+
+    def assert_op_list(self, op_list_after_fusion):
         if not self.passes:
             raise ValueError(
-                'In PassAutoScan you should give a valid pass name.')
+                "In PassAutoScan you should give a valid pass name.")
         last_passed_program = os.path.join(self.cache_dir,
-                                           self.passes[-1] + '.pdmodel')
+                                           self.passes[-1] + ".pdmodel")
+        if not os.path.exists(last_passed_program):
+            raise ValueError(
+                "Cannot find file {}, please make sure that your pass name is correct".
+                format(last_passed_program))
         model_bytes = paddle.static.load_from_file(last_passed_program)
         pg = paddle.static.deserialize_program(model_bytes)
         main_block = pg.desc.block(0)
-        after_op_size = main_block.op_size()
-        pg = paddle.static.deserialize_program(origin_model)
-        main_block = pg.desc.block(0)
-        before_op_size = main_block.op_size()
-        self.assertTrue(before_op_size == fusion_before_num,
-                        'before fusion op size is {}, but got {}!'.format(
-                            before_op_size, fusion_before_num))
-        self.assertTrue(after_op_size == fusion_after_num,
-                        'after fusion op size is {}, but got {}!'.format(
-                            after_op_size, fusion_after_num))
+        after_op_list = list()
+        for i in range(main_block.op_size()):
+            if main_block.op(i).type() in ["feed", "fetch"]:
+                continue
+            after_op_list.append(main_block.op(i).type())
+        self.assertTrue(
+            op_list_after_fusion == after_op_list,
+            "Expected operator list after fusion is {}, but now it's {}".format(
+                op_list_after_fusion, after_op_list), )

-    def run_test(self, quant=False, *args, **kwargs):
+    def run_and_statis(
+            self,
+            quant=False,
+            max_examples=100,
+            reproduce=None,
+            min_success_num=25,
+            max_duration=180,
+            passes=None, ):
+        if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev":
+            max_examples *= 10
+            min_success_num *= 10
+            # in the CE phase, there is no limit on time
+            max_duration = -1
+        start_time = time.time()
+
+        settings.register_profile(
+            "ci",
+            max_examples=max_examples,
+            suppress_health_check=hypothesis.HealthCheck.all(),
+            deadline=None,
+            print_blob=True,
+            derandomize=True,
+            report_multiple_bugs=False, )
+        settings.load_profile("ci")
+        assert passes is not None, "Parameter `passes` must be defined in function run_and_statis."
+        self.passes = passes
+
+        self.add_skip_pass_case()
+
+        def program_generator(draw):
+            return self.sample_program_config(draw)
+
+        def run_test(prog_config):
+            return self.run_test(quant=quant, prog_configs=[prog_config])
+
+        generator = st.composite(program_generator)
+        loop_func = given(generator())(run_test)
+        if reproduce is not None:
+            loop_func = reproduce(loop_func)
+        logging.info("Start running tests of {}".format(type(self)))
+        loop_func()
+
+        logging.info(
+            "===================Statistical Information===================")
+        logging.info("Number of Generated Programs: {}".format(
+            self.num_ran_programs + self.num_invalid_programs))
+        logging.info("Number of Invalid Programs: {}".format(
+            self.num_invalid_programs))
+        logging.info("Number of Ran Programs: {}".format(self.num_ran_programs))
+        logging.info("Number of Skipped Tests: {}".format(
+            self.num_skipped_tests))
+        successful_ran_programs = int(self.num_ran_programs -
+                                      self.num_skipped_tests /
+                                      self.num_predictor_kinds)
+        logging.info(
+            "Number of successfully run programs is approximately {}".format(
+                successful_ran_programs))
+        if successful_ran_programs < min_success_num:
+            logging.warning(
+                "satisfied_programs = ran_programs - num_skipped_tests / num_predictor_kinds"
+            )
+            logging.error(
+                "At least {} programs need to run successfully, but only about {} did.".
+                format(min_success_num, successful_ran_programs))
+            assert False
+        used_time = time.time() - start_time
+        if max_duration > 0 and used_time > max_duration:
+            logging.error(
+                "The duration exceeds {} seconds; if this is necessary, set a larger value for parameter `max_duration`.".
+                format(max_duration))
+            assert False
+
+    def run_test(self, quant=False, prog_configs=None):
         status = True
-        for prog_config in self.sample_program_configs(*args, **kwargs):
+        for prog_config in prog_configs:
             # if the program is invalid, skip that case.
             if not self.is_program_valid(prog_config):
+                self.num_invalid_programs += 1
                 continue
+            self.num_ran_programs += 1
             model, params = create_fake_model(prog_config)
             if quant:
                 model, params = create_quant_model(model, params)

@@ -330,13 +416,16 @@ class PassAutoScanTest(AutoScanTest):
                         feed_data))
                 self.success_log('RUN_CPU_BASELINE done')

-            for pred_config, nodes_num, (
+            self.num_predictor_kinds = 0
+            for pred_config, op_list, (
                     atol, rtol) in self.sample_predictor_configs(prog_config):
+                self.num_predictor_kinds += 1
                 # skip info
                 skip_flag = False
                 for skip_info in self.skip_cases:
                     if skip_info[0](prog_config, pred_config):
                         skip_flag = True
+                        self.num_skipped_tests += 1
                         if skip_info[1] == SkipReasons.PASS_ACCURACY_ERROR:
                             self.skip_log("[PASS_ACCURACY_ERROR] " + skip_info[
                                 2] + ' ' + ' vs ' + self.inference_config_str(

@@ -357,7 +446,7 @@ class PassAutoScanTest(AutoScanTest):
                     self.assert_tensors_near(atol, rtol, results[-1],
                                              results[0])
                     if not skip_flag:
-                        self.assert_op_size(nodes_num[0], nodes_num[1], model)
+                        self.assert_op_list(op_list)
                 except Exception as e:
                     self.fail_log(
......
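The core of `run_and_statis` is the bridge from a draw-style generator to a Hypothesis-driven loop. A self-contained sketch of that pattern, with illustrative names that are not part of the framework:

import hypothesis
import hypothesis.strategies as st
from hypothesis import given, settings

# A stand-in for sample_program_config: draw a few parameters per example.
def program_generator(draw):
    return {
        "batch_size": draw(st.integers(min_value=1, max_value=4)),
        "axis": draw(st.integers(min_value=-1, max_value=2)),
    }

# A stand-in for PassAutoScanTest.run_test: just validate one drawn config.
def run_test(prog_config):
    assert 1 <= prog_config["batch_size"] <= 4

settings.register_profile(
    "ci",
    max_examples=20,
    suppress_health_check=hypothesis.HealthCheck.all(),
    deadline=None,
    print_blob=True,
    derandomize=True,
    report_multiple_bugs=False)
settings.load_profile("ci")

generator = st.composite(program_generator)  # turn the draw function into a strategy
loop_func = given(generator())(run_test)     # bind the strategy to the test body
loop_func()                                  # Hypothesis drives max_examples runs

This is exactly the shape of the `program_generator` / `run_test` / `st.composite` / `given` wiring in `run_and_statis` above, minus the statistics bookkeeping.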
@@ -34,17 +34,24 @@ class TensorConfig:
     def __init__(self,
                  lod: Optional[List[List[int]]]=None,
-                 data_gen: Optional[Callable[..., np.array]]=None):
+                 data_gen: Optional[Callable[..., np.array]]=None,
+                 shape: Optional[List[List[int]]]=None):
         '''
         shape: The shape of the tensor.
         dtype: The data type of the tensor.
         data: The value of WeightVar. For an input, it should be None.
         '''
         self.lod = lod
-        self.data_gen = data_gen
-        self.data = data_gen()
-        self.dtype = data_gen().dtype
-        self.shape = data_gen().shape
+        if data_gen is not None:
+            self.data_gen = data_gen
+            self.data = data_gen()
+            self.dtype = data_gen().dtype
+            self.shape = data_gen().shape
+        else:
+            assert shape is not None, "When data_gen is not defined, shape must not be None"
+            self.data = np.random.normal(0.0, 1.0, shape).astype(np.float32)
+            self.shape = shape
+            self.dtype = self.data.dtype

     def __repr__(self):
         return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype})
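With this change a TensorConfig can be built either from an explicit `data_gen` callable or from just a shape. A quick sketch of the two call styles, assuming `program_config` is importable from the test directory:

import numpy as np
from program_config import TensorConfig

# Old style: data, dtype and shape all come from the callable.
w = TensorConfig(data_gen=lambda: np.ones([64, 32], dtype=np.float32))
# New style: only a shape is given; data defaults to N(0, 1) as float32.
x = TensorConfig(shape=[4, 64])
print(w.shape, x.dtype)  # (64, 32) float32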
@@ -57,11 +64,15 @@ class OpConfig:
                  type: str,
                  inputs: Dict[str, List[str]],
                  outputs: Dict[str, List[str]],
-                 attrs: Dict[str, Any]):
+                 attrs: Dict[str, Any]=None,
+                 **kwargs):
         self.type = type
         self.inputs = inputs
         self.outputs = outputs
         self.attrs = attrs
+        if self.attrs is None:
+            self.attrs = dict()
+        self.attrs.update(kwargs)

     def __repr__(self):
         log_str = self.type
......
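Since `attrs` is now optional and extra keyword arguments are merged into it, the two spellings below build identical op descriptions; a small sketch, again assuming `program_config` is importable:

from program_config import OpConfig

# attrs passed as a dict...
op_a = OpConfig(
    "mul",
    inputs={"X": ["mul_x"], "Y": ["mul_y"]},
    outputs={"Out": ["mul_out"]},
    attrs={"x_num_col_dims": 1, "y_num_col_dims": 1})
# ...or as plain keyword arguments, as the new fc_fuse_pass test does.
op_b = OpConfig(
    "mul",
    inputs={"X": ["mul_x"], "Y": ["mul_y"]},
    outputs={"Out": ["mul_out"]},
    x_num_col_dims=1,
    y_num_col_dims=1)
assert op_a.attrs == op_b.attrs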
@@ -71,7 +71,19 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
         return True

-    def sample_program_configs(self, *args, **kwargs):
+    def sample_program_config(self, draw):
+        is_sparse = draw(st.booleans())
+        is_distributed = draw(st.booleans())
+        padding_idx = draw(st.integers())
+        axis = draw(st.integers(min_value=-4, max_value=4))
+        op_type = draw(st.sampled_from(['lookup_table', 'lookup_table_v2']))
+        epsilon = draw(st.floats(min_value=0, max_value=0.001))
+        # begin_norm_axis has to be 2
+        begin_norm_axis = 2
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+        input_dim = draw(st.sampled_from([32, 64]))
+        weight_size = draw(st.sampled_from([[64, 64], [64, 32]]))

         def generate_input(attrs):
             if attrs[0]['op_type'] == 'lookup_table':
                 return np.random.randint(

@@ -101,19 +113,19 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
                     np.float32)

         attrs = [{
-            'is_sparse': kwargs['is_sparse'],
-            'is_distributed': kwargs['is_distributed'],
-            'padding_idx': kwargs['padding_idx'],
-            'op_type': kwargs['op_type']
+            'is_sparse': is_sparse,
+            'is_distributed': is_distributed,
+            'padding_idx': padding_idx,
+            'op_type': op_type
         }, {
-            'axis': kwargs['axis']
+            'axis': axis
         }, {
-            'begin_norm_axis': kwargs['begin_norm_axis'],
-            'epsilon': kwargs['epsilon']
+            'begin_norm_axis': begin_norm_axis,
+            'epsilon': epsilon
         }, {
-            'batch_size': kwargs['batch_size'],
-            'input_dim': kwargs['input_dim'],
-            'weight_size': kwargs['weight_size']
+            'batch_size': batch_size,
+            'input_dim': input_dim,
+            'weight_size': weight_size
         }]

         emb_op1 = OpConfig(

@@ -203,13 +215,12 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
             },
             outputs=["layer_norm_output1"])

-        yield program_config
+        return program_config

     def sample_predictor_configs(self, program_config):
         # only used in gpu passes and trt passes.
-        config = self.create_inference_config(
-            passes=['embedding_eltwise_layernorm_fuse_pass'], use_gpu=True)
-        yield config, (10, 5), (1e-5, 1e-5)
+        config = self.create_inference_config(use_gpu=True)
+        yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5)

         # trt static_shape
         config = self.create_trt_inference_config()
         config.enable_tensorrt_engine(

@@ -219,7 +230,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
             precision_mode=paddle_infer.PrecisionType.Float32,
             use_static=False,
             use_calib_mode=False)
-        yield config, (10, 5), (1e-5, 1e-5)
+        yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5)

         # trt dynamic_shape
         config = self.create_trt_inference_config()
         config.enable_tensorrt_engine(

@@ -257,7 +268,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
                 "input_data2": [2, 128],
                 "input_data3": [2, 128]
             })
-        yield config, (10, 5), (1e-5, 1e-5)
+        yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5)

     def add_skip_pass_case(self):
         def teller1(program_config, predictor_config):

@@ -272,26 +283,13 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
         self.add_skip_case(teller1, SkipReasons.PASS_ACCURACY_ERROR,
                            "The pass output has diff in a specific case.")

-    @given(
-        is_sparse=st.booleans(),
-        is_distributed=st.booleans(),
-        padding_idx=st.integers(),
-        axis=st.integers(
-            min_value=-4, max_value=4),
-        op_type=st.sampled_from(['lookup_table', 'lookup_table_v2']),
-        epsilon=st.floats(
-            min_value=0, max_value=0.001),
-        begin_norm_axis=st.integers(
-            min_value=-4, max_value=4),
-        batch_size=st.integers(
-            min_value=1, max_value=4),
-        input_dim=st.sampled_from([32, 64]),
-        weight_size=st.sampled_from([[64, 64], [64, 32]]))
-    def test(self, *args, **kwargs):
-        assume(kwargs['begin_norm_axis'] == 2)
-        self.add_skip_pass_case()
-        self.run_test(quant=False, *args, **kwargs)
+    def test(self):
+        # this fuse needs to be fixed; for now no program can run successfully
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=["embedding_eltwise_layernorm_fuse_pass"],
+            min_success_num=0)

 if __name__ == "__main__":
......
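Note how the old `(10, 5)` node-count pairs become expected op-type lists: `assert_op_list` filters out feed/fetch ops and compares the remainder exactly. A toy illustration of that check (the op list here is made up for the example):

# What assert_op_list effectively does after the fuse pass has run:
after_op_list = ["feed", "feed", "feed",
                 "fused_embedding_eltwise_layernorm", "fetch"]
filtered = [op for op in after_op_list if op not in ("feed", "fetch")]
assert filtered == ["fused_embedding_eltwise_layernorm"]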
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -12,42 +12,159 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from __future__ import print_function
+from auto_scan_test import PassAutoScanTest, SkipReasons
+from program_config import TensorConfig, ProgramConfig, OpConfig

-import unittest
 import numpy as np
-from inference_pass_test import InferencePassTest
-import paddle.fluid as fluid
-import paddle.fluid.core as core
-from paddle.fluid.core import AnalysisConfig
-from paddle.fluid.core import PassVersionChecker
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+import unittest

+import hypothesis
+from hypothesis import given, settings, seed, example, assume, reproduce_failure
+import hypothesis.strategies as st

-class FcFusePassTest(InferencePassTest):
-    def setUp(self):
-        with fluid.program_guard(self.main_program, self.startup_program):
-            data = fluid.data(
-                name="data", shape=[-1, 128, 768], dtype="float32")
-            data_y = fluid.data(name="y", shape=[-1, 128, 768], dtype="float32")
-            fc_out1 = fluid.layers.fc(input=data, \
-                                      size=3072,
-                                      num_flatten_dims=2,
-                                      act="relu")
-            fc_out2 = fluid.layers.fc(input=fc_out1, \
-                                      size=768,
-                                      num_flatten_dims=2)
-        self.feeds = {"data": np.random.random((4, 128, 768)).astype("float32")}
-        self.fetch_list = [fc_out2]
-
-    def test_check_output(self):
-        use_gpu = [False]
-        if core.is_compiled_with_cuda():
-            use_gpu.append(True)
-        for i in range(len(use_gpu)):
-            self.check_output_with_option(use_gpu[i])
-            self.assertTrue(PassVersionChecker.IsCompatible('fc_fuse_pass'))
+class TestFcFusePass(PassAutoScanTest):
+    """
+    x_var   y_var(persistable)
+      \       /
+         mul     bias_var(persistable)
+          |
+      mul_out_var  bias_var(persistable)
+            \          /
+          elementwise_add
+    """
+
+    def sample_predictor_configs(self, program_config):
+        # cpu
+        before_num_ops = len(program_config.ops) + 2
+        config = self.create_inference_config(use_gpu=False)
+        yield config, ["fc"], (1e-5, 1e-5)
+
+        # for gpu
+        config = self.create_inference_config(use_gpu=True)
+        yield config, ["fc"], (1e-5, 1e-5)
+
+    def add_skip_pass_case(self):
+        # Here we put some skip rules to avoid known bugs
+        def teller1(program_config, predictor_config):
+            # shape of bias should be [1, mul_y_shape[-1]] or [mul_y_shape[-1]]
+            x_shape = list(program_config.inputs["mul_x"].shape)
+            y_shape = list(program_config.weights["mul_y"].shape)
+            bias_shape = program_config.weights["bias"].shape
+            if (bias_shape != [y_shape[-1], ] and
+                    bias_shape != [1, y_shape[-1]]):
+                return True
+            return False
+
+        def teller2(program_config, predictor_config):
+            # TODO: the fuse has a bug while axis != -1
+            if program_config.ops[1].attrs["axis"] != -1:
+                return True
+            return False
+
+        self.add_skip_case(
+            teller1,
+            SkipReasons.PASS_ACCURACY_ERROR,
+            "The pass output has diff while shape of bias is not [out_size] or [1, out_size].",
+        )
+        self.add_skip_case(
+            teller2,
+            SkipReasons.PASS_ACCURACY_ERROR,
+            "The pass output has diff while axis of elementwise_add is not -1.",
+        )
+
+    def is_program_valid(self, prog_config):
+        add_x_rank = prog_config.ops[0].attrs["x_num_col_dims"] + 1
+        add_y_rank = len(prog_config.weights["bias"].shape)
+        axis = prog_config.ops[1].attrs["axis"]
+        if add_x_rank == add_y_rank:
+            if axis != -1 and axis != 0:
+                return False
+        return True
+
+    def sample_program_config(self, draw):
+        # 1. Generate shape of input:X of mul
+        x_shape = draw(
+            st.lists(
+                st.integers(
+                    min_value=1, max_value=4), min_size=2, max_size=4))
+        # 2. Generate attr:x_num_col_dims/y_num_col_dims of mul
+        x_num_col_dims = draw(
+            st.integers(
+                min_value=1, max_value=len(x_shape) - 1))
+        y_num_col_dims = 1
+        # 3. Generate legal shape of input:Y of mul
+        y_shape = draw(
+            st.lists(
+                st.integers(
+                    min_value=1, max_value=8), min_size=2, max_size=2))
+        y_shape[0] = int(np.prod(x_shape[x_num_col_dims:]))
+        # 4. Generate legal attr:axis of elementwise_add
+        mul_out_shape = x_shape[:x_num_col_dims] + y_shape[1:]
+        axis = draw(st.integers(min_value=-1, max_value=x_num_col_dims))
+        # 5. Generate legal shape of input:Y of elementwise_add
+        if axis >= 0:
+            max_bias_rank = x_num_col_dims + 1 - axis
+            bias_rank = draw(st.integers(min_value=1, max_value=max_bias_rank))
+            bias_shape = mul_out_shape[axis:axis + bias_rank]
+        else:
+            max_bias_rank = 1
+            bias_rank = draw(
+                st.integers(
+                    min_value=1, max_value=len(mul_out_shape)))
+            bias_shape = mul_out_shape[-1 * bias_rank:]
+        # 6. Randomly choose whether to use broadcast for elementwise_add, e.g. [3, 4] -> [1, 4]
+        if draw(st.booleans()):
+            broadcast_dims = draw(st.integers(min_value=1, max_value=bias_rank))
+            for i in range(0, broadcast_dims):
+                bias_shape[i] = 1
+        # 7. Randomly choose whether to append a relu operator
+        has_relu = draw(st.booleans())
+
+        # Now we have all the parameters needed to compose a program:
+        #   shapes of input/weight tensors: x_shape, y_shape, bias_shape...
+        #   operator attributes: x_num_col_dims, y_num_col_dims, axis...
+        #   a random boolean (has_relu) deciding whether the program includes a relu op.
+        # There is still some risk that a composed program is invalid or triggers
+        # a bug while running. Use `is_program_valid` to filter invalid programs
+        # before running, and `add_skip_pass_case` to ignore programs that are
+        # known to trigger bugs.
+        mul_op = OpConfig(
+            "mul",
+            inputs={"X": ["mul_x"],
+                    "Y": ["mul_y"]},
+            outputs={"Out": ["mul_out"]},
+            x_num_col_dims=x_num_col_dims,
+            y_num_col_dims=y_num_col_dims, )
+        add_op = OpConfig(
+            "elementwise_add",
+            inputs={"X": ["mul_out"],
+                    "Y": ["bias"]},
+            outputs={"Out": ["add_out"]},
+            axis=axis, )
+        ops = [mul_op, add_op]
+        if has_relu:
+            relu_op = OpConfig(
+                "relu",
+                inputs={"X": ["add_out"]},
+                outputs={"Out": ["relu_out"]})
+            ops.append(relu_op)
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={
+                "mul_y": TensorConfig(shape=y_shape),
+                "bias": TensorConfig(shape=bias_shape),
+            },
+            inputs={"mul_x": TensorConfig(shape=x_shape), },
+            outputs=ops[-1].outputs["Out"], )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False, max_examples=300, passes=["fc_fuse_pass"])

 if __name__ == "__main__":
......
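As a worked instance of the shape constraints in `sample_program_config` above (concrete values chosen only for illustration):

import numpy as np

x_shape = [2, 3, 4]        # input X of mul, rank 3
x_num_col_dims = 2         # mul flattens X to [2 * 3, 4]
y_shape = [int(np.prod(x_shape[x_num_col_dims:])), 5]   # [4, 5]
mul_out_shape = x_shape[:x_num_col_dims] + y_shape[1:]  # [2, 3, 5]
axis = 1                   # elementwise_add aligns bias starting at axis 1
bias_rank = 2              # allowed, since x_num_col_dims + 1 - axis == 2
bias_shape = mul_out_shape[axis:axis + bias_rank]       # [3, 5]
assert bias_shape == [3, 5]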