未验证 提交 0ecf441a 编写于 作者: A arlesniak 提交者: GitHub

Add support for mkldnn ops types selection with FLAGS in dygraph (#27482)

* Add support for mkldnn ops types selection with FLAGS in dygraph

* use regex to match DNNL verbose

* python3 encoding fix
上级 e7b476c1
......@@ -22,6 +22,8 @@
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool(use_mkldnn);
DECLARE_string(tracer_mkldnn_ops_on);
DECLARE_string(tracer_mkldnn_ops_off);
namespace paddle {
namespace imperative {
......@@ -50,7 +52,17 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const platform::Place& place, bool trace_backward) {
VLOG(1) << "Trace Op: " << type;
if (FLAGS_use_mkldnn) {
attrs["use_mkldnn"] = true;
// if both lists are empty all ops are enabled (default for
// FLAGS_use_mkldnn=1)
// if ops_on list is not empty only ops from that list are enabled
if (!FLAGS_tracer_mkldnn_ops_on.empty()) {
auto is_on = FLAGS_tracer_mkldnn_ops_on.find(type) != std::string::npos;
attrs["use_mkldnn"] = is_on;
} else {
// if ops_on list is empty all ops are enabled except types from off_list
auto is_off = FLAGS_tracer_mkldnn_ops_off.find(type) != std::string::npos;
attrs["use_mkldnn"] = !is_off;
}
}
auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
const auto& op_info = op->Info();
......
......@@ -536,3 +536,25 @@ DEFINE_int32(
"gradient accumulation, if the number of gradients need to that "
"less FLAGS_max_inplace_grad_add, than it will be use several grad_add"
"instead of sum. Default is 0.");
/**
* Debug related FLAG
* Name: tracer_mkldnn_ops_on
* Since Version: 2.0.0
* Value Range: string, default=empty
* Example:
* Note: Holds list of operation types with OneDNN kernels to be enabled.
*/
DEFINE_string(tracer_mkldnn_ops_on, "",
"List of OneDNN operation types to be turned on");
/**
* Debug related FLAG
* Name: tracer_mkldnn_ops_off
* Since Version: 2.0.0
* Value Range: string, default=empty
* Example:
* Note: Holds list of operation types with OneDNN kernels to be disabled.
*/
DEFINE_string(tracer_mkldnn_ops_off, "",
"List of OneDNN operation types to be turned off");
......@@ -31,6 +31,8 @@
// data processing
DECLARE_bool(use_mkldnn);
DECLARE_string(tracer_mkldnn_ops_on);
DECLARE_string(tracer_mkldnn_ops_off);
// debug
DECLARE_bool(check_nan_inf);
DECLARE_bool(cpu_deterministic);
......@@ -349,7 +351,8 @@ static void RegisterGlobalVarGetterSetter() {
FLAGS_init_allocated_mem, FLAGS_initial_cpu_memory_in_mb,
FLAGS_memory_fraction_of_eager_deletion, FLAGS_use_pinned_memory,
FLAGS_benchmark, FLAGS_inner_op_parallelism, FLAGS_tracer_profile_fname,
FLAGS_paddle_num_threads, FLAGS_use_mkldnn, FLAGS_max_inplace_grad_add);
FLAGS_paddle_num_threads, FLAGS_use_mkldnn, FLAGS_max_inplace_grad_add,
FLAGS_tracer_mkldnn_ops_on, FLAGS_tracer_mkldnn_ops_off);
#ifdef PADDLE_WITH_CUDA
REGISTER_PUBLIC_GLOBAL_VAR(
......
......@@ -210,6 +210,8 @@ def __bootstrap__():
if core.is_compiled_with_mkldnn():
read_env_flags.append('use_mkldnn')
read_env_flags.append('tracer_mkldnn_ops_on')
read_env_flags.append('tracer_mkldnn_ops_off')
if core.is_compiled_with_dist():
#env for rpc
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import unicode_literals
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import os
from paddle.fluid.layer_helper import LayerHelper
def check():
print("check: fluid.core.globals()['FLAGS_use_mkldnn']=",
fluid.core.globals()["FLAGS_use_mkldnn"])
print("check: fluid.get_flags('FLAGS_use_mkldnn')=",
fluid.get_flags(['FLAGS_use_mkldnn']))
print("check: DNNL_VERBOSE=", os.environ['DNNL_VERBOSE'])
print("check: FLAGS_tracer_mkldnn_ops_on=",
fluid.core.globals()['FLAGS_tracer_mkldnn_ops_on'])
print("check: FLAGS_tracer_mkldnn_ops_off=",
fluid.core.globals()['FLAGS_tracer_mkldnn_ops_off'])
a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32)
b_np = np.random.uniform(-5, 5, (10, 20, 30)).astype(np.float32)
helper = LayerHelper(fluid.unique_name.generate(str("test")), act="relu")
func = helper.append_activation
with fluid.dygraph.guard(fluid.core.CPUPlace()):
a = fluid.dygraph.to_variable(a_np)
b = fluid.dygraph.to_variable(b_np)
y = fluid.layers.elementwise_add(x=a, y=b)
y = fluid.layers.matmul(x=y, y=b, transpose_y=True)
res1 = func(y)
np_res = np.add(a_np, b_np)
np_res = np.matmul(np_res, np.transpose(b_np, (0, 2, 1)))
np_res = np.maximum(np_res, 0)
assert np.allclose(res1.numpy(), np_res, atol=1e-3)
if __name__ == '__main__':
try:
check()
except Exception as e:
print(e)
print(type(e))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import unicode_literals
from __future__ import print_function
import unittest
import os
import sys
import subprocess
import re
class TestFlagsUseMkldnn(unittest.TestCase):
def setUp(self):
self._python_interp = sys.executable
self._python_interp += " check_flags_mkldnn_ops_on_off.py"
self.env = os.environ.copy()
self.env[str("DNNL_VERBOSE")] = str("1")
self.env[str("FLAGS_use_mkldnn")] = str("1")
self.relu_regex = b"^dnnl_verbose,exec,cpu,eltwise,.+alg:eltwise_relu alpha:0 beta:0,10x20x20"
self.ew_add_regex = b"^dnnl_verbose,exec,cpu,binary.+alg:binary_add,10x20x30:10x20x30 10x20x30"
self.matmul_regex = b"^dnnl_verbose,exec,cpu,matmul,.*b10m20n20k30"
def flags_use_mkl_dnn_common(self, e):
cmd = self._python_interp
env = dict(self.env, **e)
proc = subprocess.Popen(
cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env)
out, err = proc.communicate()
returncode = proc.returncode
assert returncode == 0
return out
def found(self, regex, out):
return re.search(regex, out, re.MULTILINE)
def test_flags_use_mkl_dnn_on_empty_off_empty(self):
out = self.flags_use_mkl_dnn_common({})
assert self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_on(self):
env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu")}
out = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out)
assert not self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_on_multiple(self):
env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu,elementwise_add")}
out = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_off(self):
env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")}
out = self.flags_use_mkl_dnn_common(env)
assert self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_off_multiple(self):
env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul,relu")}
out = self.flags_use_mkl_dnn_common(env)
assert not self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
def test_flags_use_mkl_dnn_on_off(self):
env = {
str("FLAGS_tracer_mkldnn_ops_on"): str("elementwise_add"),
str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")
}
out = self.flags_use_mkl_dnn_common(env)
assert not self.found(self.relu_regex, out)
assert self.found(self.ew_add_regex, out)
assert not self.found(self.matmul_regex, out)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册