未验证 提交 b781953e 编写于 作者: A arlesniak 提交者: GitHub

[oneDNN] Fix flags use test for #29080, assert condition more general (#29493)

* Flags assert condition more general, print output if pattern not found

* removed test_flags_use_mkldnn form skip list regarding #29080 descr
上级 ae3f7a71
......@@ -418,7 +418,7 @@ test_fuse_optimizer_pass^|test_generator_dataloader^|test_ir_memory_optimize_ife
test_multiprocess_dataloader_iterable_dataset_dynamic^|test_multiprocess_dataloader_iterable_dataset_static^|test_parallel_dygraph_sync_batch_norm^|test_parallel_executor_drop_scope^|^
test_parallel_executor_dry_run^|test_partial_eager_deletion_transformer^|test_prune^|test_py_reader_combination^|test_py_reader_pin_memory^|^
test_py_reader_push_pop^|test_py_reader_using_executor^|test_reader_reset^|test_update_loss_scaling_op^|test_imperative_static_runner_while^|^
test_flags_use_mkldnn^|test_optimizer_in_control_flow^|test_fuse_bn_act_pass^|^
test_optimizer_in_control_flow^|test_fuse_bn_act_pass^|^
test_fuse_bn_add_act_pass^|test_activation_mkldnn_op^|test_tsm^|test_gru_rnn_op^|test_rnn_op^|test_simple_rnn_op^|test_pass_builder^|test_lstm_cudnn_op^|test_inplace_addto_strategy^|^
test_ir_inplace_pass^|test_ir_memory_optimize_pass^|test_memory_reuse_exclude_feed_var^|test_mix_precision_all_reduce_fuse^|test_parallel_executor_pg^|test_print_op^|test_py_func_op^|^
test_weight_decay^|^
......
......@@ -48,54 +48,65 @@ class TestFlagsUseMkldnn(unittest.TestCase):
returncode = proc.returncode
assert returncode == 0
return out
return out, err
def found(self, regex, out):
return re.search(regex, out, re.MULTILINE)
def _print_when_false(self, cond, out, err):
if not cond:
print('out', out)
print('err', err)
return cond
def found(self, regex, out, err):
_found = re.search(regex, out, re.MULTILINE)
return self._print_when_false(_found, out, err)
def not_found(self, regex, out, err):
_not_found = not re.search(regex, out, re.MULTILINE)
return self._print_when_false(_not_found, out, err)
def test_flags_use_mkl_dnn_on_empty_off_empty(self):
out = self.flags_use_mkl_dnn_common({})
assert self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert self.found(self.matmul_regex, out)
out, err = self.flags_use_mkl_dnn_common({})
assert self.found(self.relu_regex, out, err)
assert self.found(self.ew_add_regex, out, err)
assert self.found(self.matmul_regex, out, err)
def test_flags_use_mkl_dnn_on(self):
env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu")}
out = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out)
assert not self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
out, err = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out, err)
assert self.not_found(self.ew_add_regex, out, err)
assert self.not_found(self.matmul_regex, out, err)
def test_flags_use_mkl_dnn_on_multiple(self):
env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu,elementwise_add")}
out = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
out, err = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out, err)
assert self.found(self.ew_add_regex, out, err)
assert self.not_found(self.matmul_regex, out, err)
def test_flags_use_mkl_dnn_off(self):
env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")}
out = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
out, err = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out, err)
assert self.found(self.ew_add_regex, out, err)
assert self.not_found(self.matmul_regex, out, err)
def test_flags_use_mkl_dnn_off_multiple(self):
env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul,relu")}
out = self.flags_use_mkl_dnn_common(env)
assert not self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
out, err = self.flags_use_mkl_dnn_common(env)
assert self.not_found(self.relu_regex, out, err)
assert self.found(self.ew_add_regex, out, err)
assert self.not_found(self.matmul_regex, out, err)
def test_flags_use_mkl_dnn_on_off(self):
env = {
str("FLAGS_tracer_mkldnn_ops_on"): str("elementwise_add"),
str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")
}
out = self.flags_use_mkl_dnn_common(env)
assert not self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
out, err = self.flags_use_mkl_dnn_common(env)
assert self.not_found(self.relu_regex, out, err)
assert self.found(self.ew_add_regex, out, err)
assert self.not_found(self.matmul_regex, out, err)
if __name__ == '__main__':
......
......@@ -19,6 +19,7 @@ import unittest
import os
import sys
import subprocess
import re
class TestFlagsUseMkldnn(unittest.TestCase):
......@@ -27,10 +28,22 @@ class TestFlagsUseMkldnn(unittest.TestCase):
self._python_interp += " check_flags_use_mkldnn.py"
self.env = os.environ.copy()
self.env[str("GLOG_v")] = str("3")
self.env[str("GLOG_v")] = str("1")
self.env[str("DNNL_VERBOSE")] = str("1")
self.env[str("FLAGS_use_mkldnn")] = str("1")
self.relu_regex = b"^dnnl_verbose,exec,cpu,eltwise,.+alg:eltwise_relu alpha:0 beta:0,10x20x30"
def _print_when_false(self, cond, out, err):
if not cond:
print('out', out)
print('err', err)
return cond
def found(self, regex, out, err):
_found = re.search(regex, out, re.MULTILINE)
return self._print_when_false(_found, out, err)
def test_flags_use_mkl_dnn(self):
cmd = self._python_interp
......@@ -43,15 +56,8 @@ class TestFlagsUseMkldnn(unittest.TestCase):
out, err = proc.communicate()
returncode = proc.returncode
print('out', out)
print('err', err)
assert returncode == 0
# in python3, type(out) is 'bytes', need use encode
assert out.find(
"dnnl_verbose,exec,cpu,eltwise,jit:avx512_common,forward_training,"
"data_f32::blocked:abc:f0 diff_undef::undef::f0,,alg:eltwise_relu".
encode()) != -1
assert self.found(self.relu_regex, out, err)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册