diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index dde48a27459f6e5c47adcbd471a4702f990c7d99..2f802a775b52401a7d6c7ccaba12ce593b91cbe3 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -58,8 +58,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, if (!FLAGS_tracer_mkldnn_ops_on.empty()) { auto is_on = FLAGS_tracer_mkldnn_ops_on.find(type) != std::string::npos; attrs["use_mkldnn"] = is_on; - } else { // if ops_on list is empty all ops are enabled except types from - // off_list + } else { + // if ops_on list is empty all ops are enabled except types from off_list auto is_off = FLAGS_tracer_mkldnn_ops_off.find(type) != std::string::npos; attrs["use_mkldnn"] = !is_off; } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_flags_mkldnn_ops_on_off.py b/python/paddle/fluid/tests/unittests/mkldnn/test_flags_mkldnn_ops_on_off.py index 0edd6deab27d0981740896f2eeefb6ef05e891d8..52dc71c0ed84e5a2b0e3c85c3cbf9e16a71ca240 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_flags_mkldnn_ops_on_off.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_flags_mkldnn_ops_on_off.py @@ -19,6 +19,7 @@ import unittest import os import sys import subprocess +import re class TestFlagsUseMkldnn(unittest.TestCase): @@ -30,12 +31,9 @@ class TestFlagsUseMkldnn(unittest.TestCase): self.env[str("DNNL_VERBOSE")] = str("1") self.env[str("FLAGS_use_mkldnn")] = str("1") - self.relu_str = "dnnl_verbose,exec,cpu,eltwise,jit:avx512_common,forward_training," \ - "data_f32::blocked:abc:f0 diff_undef::undef::f0,,alg:eltwise_relu alpha:0 beta:0,10x20x20" - self.ew_add_str = "dnnl_verbose,exec,cpu,binary,jit:uni,undef,src_f32::blocked:abc:f0 " \ - "src_f32::blocked:abc:f0 dst_f32::blocked:abc:f0,,alg:binary_add,10x20x30:10x20x30 10x20x30" - self.matmul_str = "dnnl_verbose,exec,cpu,matmul,gemm:jit,undef,src_f32::blocked:abc:f0 " \ - "wei_f32::blocked:acb:f0 dst_f32::blocked:abc:f0,,,b10m20n20k30" + self.relu_regex = r"^dnnl_verbose,exec,cpu,eltwise,.+alg:eltwise_relu alpha:0 beta:0,10x20x20" + self.ew_add_regex = r"^dnnl_verbose,exec,cpu,binary.+alg:binary_add,10x20x30:10x20x30 10x20x30" + self.matmul_regex = r"^dnnl_verbose,exec,cpu,matmul,.*b10m20n20k30" def flags_use_mkl_dnn_common(self, e): cmd = self._python_interp @@ -49,48 +47,55 @@ class TestFlagsUseMkldnn(unittest.TestCase): out, err = proc.communicate() returncode = proc.returncode - print('out', out) - print('err', err) - assert returncode == 0 return out - def found(self, res): - return res != -1 + def found(self, regex, out): + return re.search(regex, out, re.MULTILINE) def test_flags_use_mkl_dnn_on_empty_off_empty(self): - # in python3, type(out) is 'bytes', need use encode out = self.flags_use_mkl_dnn_common({}) - assert self.found(out.find(self.relu_str.encode())) - assert self.found(out.find(self.ew_add_str.encode())) - assert self.found(out.find(self.matmul_str.encode())) + assert self.found(self.relu_regex, out) + assert self.found(self.ew_add_regex, out) + assert self.found(self.matmul_regex, out) def test_flags_use_mkl_dnn_on(self): - # in python3, type(out) is 'bytes', need use encode env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu")} out = self.flags_use_mkl_dnn_common(env) - assert self.found(out.find(self.relu_str.encode())) - assert not self.found(out.find(self.ew_add_str.encode())) - assert not self.found(out.find(self.matmul_str.encode())) + assert self.found(self.relu_regex, out) + assert not self.found(self.ew_add_regex, out) + assert not self.found(self.matmul_regex, out) + + def test_flags_use_mkl_dnn_on_multiple(self): + env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu,elementwise_add")} + out = self.flags_use_mkl_dnn_common(env) + assert self.found(self.relu_regex, out) + assert self.found(self.ew_add_regex, out) + assert not self.found(self.matmul_regex, out) def test_flags_use_mkl_dnn_off(self): - # in python3, type(out) is 'bytes', need use encode env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")} out = self.flags_use_mkl_dnn_common(env) - assert self.found(out.find(self.relu_str.encode())) - assert self.found(out.find(self.ew_add_str.encode())) - assert not self.found(out.find(self.matmul_str.encode())) + assert self.found(self.relu_regex, out) + assert self.found(self.ew_add_regex, out) + assert not self.found(self.matmul_regex, out) + + def test_flags_use_mkl_dnn_off_multiple(self): + env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul,relu")} + out = self.flags_use_mkl_dnn_common(env) + assert not self.found(self.relu_regex, out) + assert self.found(self.ew_add_regex, out) + assert not self.found(self.matmul_regex, out) def test_flags_use_mkl_dnn_on_off(self): - # in python3, type(out) is 'bytes', need use encode env = { str("FLAGS_tracer_mkldnn_ops_on"): str("elementwise_add"), str("FLAGS_tracer_mkldnn_ops_off"): str("matmul") } out = self.flags_use_mkl_dnn_common(env) - assert not self.found(out.find(self.relu_str.encode())) - assert self.found(out.find(self.ew_add_str.encode())) - assert not self.found(out.find(self.matmul_str.encode())) + assert not self.found(self.relu_regex, out) + assert self.found(self.ew_add_regex, out) + assert not self.found(self.matmul_regex, out) if __name__ == '__main__':