提交 772faf06 编写于 作者: A arlesniak

use regex to match DNNL verbose

上级 622e3a44
...@@ -58,8 +58,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -58,8 +58,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
if (!FLAGS_tracer_mkldnn_ops_on.empty()) { if (!FLAGS_tracer_mkldnn_ops_on.empty()) {
auto is_on = FLAGS_tracer_mkldnn_ops_on.find(type) != std::string::npos; auto is_on = FLAGS_tracer_mkldnn_ops_on.find(type) != std::string::npos;
attrs["use_mkldnn"] = is_on; attrs["use_mkldnn"] = is_on;
} else { // if ops_on list is empty all ops are enabled except types from } else {
// off_list // if ops_on list is empty all ops are enabled except types from off_list
auto is_off = FLAGS_tracer_mkldnn_ops_off.find(type) != std::string::npos; auto is_off = FLAGS_tracer_mkldnn_ops_off.find(type) != std::string::npos;
attrs["use_mkldnn"] = !is_off; attrs["use_mkldnn"] = !is_off;
} }
......
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
import os import os
import sys import sys
import subprocess import subprocess
import re
class TestFlagsUseMkldnn(unittest.TestCase): class TestFlagsUseMkldnn(unittest.TestCase):
...@@ -30,12 +31,9 @@ class TestFlagsUseMkldnn(unittest.TestCase): ...@@ -30,12 +31,9 @@ class TestFlagsUseMkldnn(unittest.TestCase):
self.env[str("DNNL_VERBOSE")] = str("1") self.env[str("DNNL_VERBOSE")] = str("1")
self.env[str("FLAGS_use_mkldnn")] = str("1") self.env[str("FLAGS_use_mkldnn")] = str("1")
self.relu_str = "dnnl_verbose,exec,cpu,eltwise,jit:avx512_common,forward_training," \ self.relu_regex = r"^dnnl_verbose,exec,cpu,eltwise,.+alg:eltwise_relu alpha:0 beta:0,10x20x20"
"data_f32::blocked:abc:f0 diff_undef::undef::f0,,alg:eltwise_relu alpha:0 beta:0,10x20x20" self.ew_add_regex = r"^dnnl_verbose,exec,cpu,binary.+alg:binary_add,10x20x30:10x20x30 10x20x30"
self.ew_add_str = "dnnl_verbose,exec,cpu,binary,jit:uni,undef,src_f32::blocked:abc:f0 " \ self.matmul_regex = r"^dnnl_verbose,exec,cpu,matmul,.*b10m20n20k30"
"src_f32::blocked:abc:f0 dst_f32::blocked:abc:f0,,alg:binary_add,10x20x30:10x20x30 10x20x30"
self.matmul_str = "dnnl_verbose,exec,cpu,matmul,gemm:jit,undef,src_f32::blocked:abc:f0 " \
"wei_f32::blocked:acb:f0 dst_f32::blocked:abc:f0,,,b10m20n20k30"
def flags_use_mkl_dnn_common(self, e): def flags_use_mkl_dnn_common(self, e):
cmd = self._python_interp cmd = self._python_interp
...@@ -49,48 +47,55 @@ class TestFlagsUseMkldnn(unittest.TestCase): ...@@ -49,48 +47,55 @@ class TestFlagsUseMkldnn(unittest.TestCase):
out, err = proc.communicate() out, err = proc.communicate()
returncode = proc.returncode returncode = proc.returncode
print('out', out)
print('err', err)
assert returncode == 0 assert returncode == 0
return out return out
def found(self, res): def found(self, regex, out):
return res != -1 return re.search(regex, out, re.MULTILINE)
def test_flags_use_mkl_dnn_on_empty_off_empty(self): def test_flags_use_mkl_dnn_on_empty_off_empty(self):
# in python3, type(out) is 'bytes', need use encode
out = self.flags_use_mkl_dnn_common({}) out = self.flags_use_mkl_dnn_common({})
assert self.found(out.find(self.relu_str.encode())) assert self.found(self.relu_regex, out)
assert self.found(out.find(self.ew_add_str.encode())) assert self.found(self.ew_add_regex, out)
assert self.found(out.find(self.matmul_str.encode())) assert self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_on(self): def test_flags_use_mkl_dnn_on(self):
# in python3, type(out) is 'bytes', need use encode
env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu")} env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu")}
out = self.flags_use_mkl_dnn_common(env) out = self.flags_use_mkl_dnn_common(env)
assert self.found(out.find(self.relu_str.encode())) assert self.found(self.relu_regex, out)
assert not self.found(out.find(self.ew_add_str.encode())) assert not self.found(self.ew_add_regex, out)
assert not self.found(out.find(self.matmul_str.encode())) assert not self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_on_multiple(self):
env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu,elementwise_add")}
out = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_off(self): def test_flags_use_mkl_dnn_off(self):
# in python3, type(out) is 'bytes', need use encode
env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")} env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")}
out = self.flags_use_mkl_dnn_common(env) out = self.flags_use_mkl_dnn_common(env)
assert self.found(out.find(self.relu_str.encode())) assert self.found(self.relu_regex, out)
assert self.found(out.find(self.ew_add_str.encode())) assert self.found(self.ew_add_regex, out)
assert not self.found(out.find(self.matmul_str.encode())) assert not self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_off_multiple(self):
env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul,relu")}
out = self.flags_use_mkl_dnn_common(env)
assert not self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_on_off(self): def test_flags_use_mkl_dnn_on_off(self):
# in python3, type(out) is 'bytes', need use encode
env = { env = {
str("FLAGS_tracer_mkldnn_ops_on"): str("elementwise_add"), str("FLAGS_tracer_mkldnn_ops_on"): str("elementwise_add"),
str("FLAGS_tracer_mkldnn_ops_off"): str("matmul") str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")
} }
out = self.flags_use_mkl_dnn_common(env) out = self.flags_use_mkl_dnn_common(env)
assert not self.found(out.find(self.relu_str.encode())) assert not self.found(self.relu_regex, out)
assert self.found(out.find(self.ew_add_str.encode())) assert self.found(self.ew_add_regex, out)
assert not self.found(out.find(self.matmul_str.encode())) assert not self.found(self.matmul_regex, out)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册