Commit 9a15c923 authored by pzelazko-intel, committed by Tao Luo

bnorm+relu fuse for mkldnn (inference) (#11434)

* bnorm+relu fuse for mkldnn

* separate fuse_relu function

* bug fix

* proper while range in inference_transpiler

* description fix

* review fix

* review fix

* unit test for fwd batch norm+relu MKLDNN fuse
Parent ce5f1e05
...@@ -122,5 +122,9 @@ def parse_args():
        type=str,
        default="",
        help='Directory that contains all the training recordio files.')
    parser.add_argument(
        '--use_inference_transpiler',
        action='store_true',
        help='If set, uses inference transpiler to optimize the program.')
    args = parser.parse_args()
    return args
...@@ -131,6 +131,11 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    # Use inference_transpiler to speedup
    if args.use_inference_transpiler:
        t = fluid.InferenceTranspiler()
        t.transpile(infer_prog, place)

    if not args.use_reader_op:
        feed_var_list = [
            var for var in train_prog.global_block().vars.itervalues()
......
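For context, a minimal self-contained sketch of the transpiler call added above (the small program built here is illustrative; the benchmark builds its own infer_prog):

    import paddle.fluid as fluid

    # Build a small inference-style program to transpile.
    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        img = fluid.layers.data(name='img', shape=[3, 32, 32], dtype='float32')
        conv = fluid.layers.conv2d(img, num_filters=8, filter_size=3)
        out = fluid.layers.batch_norm(conv, act='relu', is_test=True)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    # The transpiler rewrites the program in place (e.g. fusing batch_norm
    # into the preceding conv2d and, with MKLDNN, relu into batch_norm),
    # so it should run after parameter initialization and before inference.
    t = fluid.InferenceTranspiler()
    t.transpile(main_prog, place)

In the benchmark above, the same two transpiler lines run right after exe.run(startup_prog) when --use_inference_transpiler is passed.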
...@@ -66,6 +66,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    const float epsilon = ctx.Attr<float>("epsilon");
    const float momentum = ctx.Attr<float>("momentum");
    const bool is_test = ctx.Attr<bool>("is_test");
    const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");

    const auto *x = ctx.Input<Tensor>("X");
    const auto *mean = ctx.Input<Tensor>("Mean");
...@@ -111,6 +112,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    unsigned flags = mkldnn::use_scale_shift;
    if (is_test) flags |= mkldnn::use_global_stats;
    if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;

    // create mkldnn memory from input x tensor
    auto src_memory =
......
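For reference, with the fuse_bn_relu flag set the primitive is expected to compute, in inference mode (a sketch in standard batch-norm notation, not quoted from the kernel):

    y = \max\left( \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta,\ 0 \right)

where \mu and \sigma^{2} are the global mean and variance, \gamma and \beta the scale and shift, and \epsilon the epsilon attribute read above; without the flag the outer max is simply omitted.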
...@@ -155,6 +155,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("fuse_with_relu",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddComment(R"DOC(
Batch Normalization.
......
...@@ -1993,7 +1993,8 @@ def batch_norm(input,
               name=None,
               moving_mean_name=None,
               moving_variance_name=None,
               do_model_average_for_mean_and_var=False,
               fuse_with_relu=False):
    """
    **Batch Normalization Layer**
...@@ -2036,6 +2037,7 @@ def batch_norm(input,
        moving_mean_name(string, Default None): The name of moving_mean which store the global Mean.
        moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance.
        do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not.
        fuse_with_relu (bool): if True, this OP performs relu after batch norm.

    Returns:
        Variable: A tensor variable which is the result after applying batch normalization on the input.
...@@ -2121,7 +2123,8 @@ def batch_norm(input,
            "momentum": momentum,
            "epsilon": epsilon,
            "is_test": is_test,
            "use_mkldnn": use_mkldnn,
            "fuse_with_relu": fuse_with_relu
        })

    return helper.append_activation(batch_norm_out)
......
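A minimal sketch of the new argument from the Python side (tensor names and shapes are illustrative, and the flag only takes effect in the MKLDNN inference kernel):

    import paddle.fluid as fluid

    img = fluid.layers.data(name='img', shape=[3, 32, 32], dtype='float32')
    conv = fluid.layers.conv2d(img, num_filters=8, filter_size=3)
    # Ask the MKLDNN batch_norm kernel to apply relu itself instead of
    # emitting a separate relu op after it.
    out = fluid.layers.batch_norm(
        conv, is_test=True, use_mkldnn=True, fuse_with_relu=True)

Note the design split: act='relu' still appends a separate relu op via helper.append_activation, while fuse_with_relu folds the activation into the batch_norm op itself.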
...@@ -52,5 +52,17 @@ class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference):
        self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])


class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
    def init_kernel_type(self):
        self.use_mkldnn = True
        self.fuse_with_relu = True

    def test_check_output(self):
        place = core.CPUPlace()
        data_format = "NCHW"

        self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])


if __name__ == '__main__':
    unittest.main()
...@@ -159,6 +159,7 @@ class TestBatchNormOpInference(unittest.TestCase):
    def setUp(self):
        self.dtype = np.float32
        self.use_mkldnn = False
        self.fuse_with_relu = False
        self.init_kernel_type()

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
...@@ -180,6 +181,8 @@ class TestBatchNormOpInference(unittest.TestCase):
        scale_shape = [c]

        x_val = np.random.random_sample(x_shape).astype(dtype)
        # generate some negative values to test case with relu fused
        x_val = x_val - 0.5
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
...@@ -188,6 +191,8 @@ class TestBatchNormOpInference(unittest.TestCase):
        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
                                   epsilon, data_layout).astype(dtype)
        if self.fuse_with_relu:
            y_out = np.maximum(y_out, 0)

        scope = core.Scope()
...@@ -233,6 +238,7 @@ class TestBatchNormOpInference(unittest.TestCase):
            is_test=True,
            data_layout=data_layout,
            use_mkldnn=self.use_mkldnn,
            fuse_with_relu=self.fuse_with_relu,
            epsilon=epsilon)

        batch_norm_op.run(scope, place)
...@@ -265,6 +271,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
    def setUp(self):
        self.dtype = np.float16
        self.use_mkldnn = False
        self.fuse_with_relu = False
        self.init_kernel_type()

    def test_check_output(self):
...@@ -284,6 +291,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
class TestBatchNormOpTraining(unittest.TestCase):
    def setUp(self):
        self.use_mkldnn = False
        self.fuse_with_relu = False
        self.data_formats = ["NCHW", "NHWC"]
        self.init_kernel_type()
...@@ -367,7 +375,8 @@ class TestBatchNormOpTraining(unittest.TestCase):
                    "epsilon": epsilon,
                    "is_test": False,
                    "data_layout": data_layout,
                    "use_mkldnn": self.use_mkldnn,
                    "fuse_with_relu": self.fuse_with_relu
                })

        block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
......
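To make the fused test's expected output concrete, here is a standalone NumPy sketch of the reference computation (inference-mode batch norm followed by relu, mirroring _reference_testing plus the np.maximum clamp above; the epsilon value and statistics are chosen for illustration):

    import numpy as np

    def bn_relu_reference(x, scale, bias, mean, variance, epsilon=1e-5):
        # NCHW layout: normalize with the saved global statistics along the
        # channel axis, apply scale/shift, then clamp at zero (the fused relu).
        c = x.shape[1]
        shape = (1, c, 1, 1)
        y = (x - mean.reshape(shape)) / np.sqrt(variance.reshape(shape) + epsilon)
        y = y * scale.reshape(shape) + bias.reshape(shape)
        return np.maximum(y, 0)

    x = np.random.random_sample((2, 3, 4, 5)).astype(np.float32) - 0.5
    scale = np.random.random_sample(3).astype(np.float32)
    bias = np.random.random_sample(3).astype(np.float32)
    mean = np.zeros(3, dtype=np.float32)
    variance = np.ones(3, dtype=np.float32)

    y = bn_relu_reference(x, scale, bias, mean, variance)
    assert (y >= 0).all()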
...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
from .. import core
from ..framework import Program
...@@ -22,7 +23,10 @@ class InferenceTranspiler:
    '''
    Convert the fluid program to optimized inference program.

    There are several optimizations:

      - fuse convolution and batch normalization
      - fuse batch normalization and relu (MKLDNN only)

    Examples:
...@@ -54,6 +58,51 @@ class InferenceTranspiler:
        if not isinstance(scope, core.Scope):
            raise TypeError("scope should be as Scope type or None")
        self.fuse_batch_norm(program, place, scope)
        self.fuse_relu_mkldnn(program)

    def fuse_relu_mkldnn(self, program):
        '''
        Transpile the program by fused relu activation for MKLDNN program.

        Relu activation following batch norm OP can be fused by adding
        :math:`fuse_with_relu` attribute to batch norm OP.

        The result of fuse is:

        - before:

          - batch_norm->relu->any_other_op

        - after:

          - batch_norm->any_other_op

        :param program: program to transpile
        :type program: Program
        '''
        use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
        if not use_mkldnn:
            return

        self.block = program.block(0)

        i = 0
        while i < len(self.block.ops) - 1:
            current_op = self.block.ops[i]
            if current_op.type in ['batch_norm']:
                next_op = self.block.ops[i + 1]
                if next_op.type == 'relu':
                    # modify bnorm OP to include relu
                    current_op.set_attr("fuse_with_relu", True)
                    # remove relu OP
                    self.block.remove_op(i + 1)
            i = i + 1

        self._remove_unused_var()
        # TODO(luotao): use clone() method to flush the program.desc in force,
        # since some large program.desc will not be flushed immediately.
        # And a better solution will be considered later.
        program = program.clone()

    def fuse_batch_norm(self, program, place, scope):
        '''
...@@ -107,7 +156,7 @@ class InferenceTranspiler:
        self.input_map = {}  # store the input names should be adjusted
        i = 0
        while i < len(self.block.ops) - 2:
            current_op = self.block.ops[i]
            # TODO(luotao1): consider only conv2d now. fc would be delt later.
            if current_op.type in ['conv2d']:
......
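A hedged end-to-end sketch of the new pass in isolation (the program and variable names are illustrative; the environment variable mirrors the use_mkldnn gate shown above, and without it fuse_relu_mkldnn returns with no changes):

    import os
    os.environ['FLAGS_use_mkldnn'] = '1'   # must be set before the pass runs

    import paddle.fluid as fluid

    prog = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(prog, startup):
        img = fluid.layers.data(name='img', shape=[3, 32, 32], dtype='float32')
        bn = fluid.layers.batch_norm(img, is_test=True, use_mkldnn=True,
                                     act='relu')

    print([op.type for op in prog.block(0).ops])   # expected: ['batch_norm', 'relu']

    t = fluid.InferenceTranspiler()
    t.fuse_relu_mkldnn(prog)

    # The relu op is removed and batch_norm now carries fuse_with_relu=True.
    print([op.type for op in prog.block(0).ops])   # expected: ['batch_norm']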