From 9a15c92317e6ac938de0279c7506a28e3c100116 Mon Sep 17 00:00:00 2001 From: pzelazko-intel Date: Wed, 27 Jun 2018 12:06:52 +0200 Subject: [PATCH] bnorm+relu fuse for mkldnn (inference) (#11434) * bnorm+relu fuse for mkldnn * separate fuse_relu function * bug fix * proper while range in inference_transpiler * description fix * review fix * review fix * unit test for fwd batch norm+relu MKLDNN fuse --- benchmark/fluid/args.py | 4 + benchmark/fluid/fluid_benchmark.py | 5 + .../fluid/operators/batch_norm_mkldnn_op.cc | 2 + paddle/fluid/operators/batch_norm_op.cc | 3 + python/paddle/fluid/layers/nn.py | 7 +- .../unittests/test_batch_norm_mkldnn_op.py | 12 +++ .../tests/unittests/test_batch_norm_op.py | 11 ++- .../fluid/transpiler/inference_transpiler.py | 99 ++++++++++++++----- 8 files changed, 115 insertions(+), 28 deletions(-) mode change 100644 => 100755 benchmark/fluid/fluid_benchmark.py diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py index 68a3d42d7a..99c9d79b06 100644 --- a/benchmark/fluid/args.py +++ b/benchmark/fluid/args.py @@ -122,5 +122,9 @@ def parse_args(): type=str, default="", help='Directory that contains all the training recordio files.') + parser.add_argument( + '--use_inference_transpiler', + action='store_true', + help='If set, uses inference transpiler to optimize the program.') args = parser.parse_args() return args diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py old mode 100644 new mode 100755 index ece1102dce..dcd4d9ea95 --- a/benchmark/fluid/fluid_benchmark.py +++ b/benchmark/fluid/fluid_benchmark.py @@ -131,6 +131,11 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, exe = fluid.Executor(place) exe.run(startup_prog) + # Use inference_transpiler to speedup + if args.use_inference_transpiler: + t = fluid.InferenceTranspiler() + t.transpile(infer_prog, place) + if not args.use_reader_op: feed_var_list = [ var for var in train_prog.global_block().vars.itervalues() diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index cc158e57f7..6ecb43c49c 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -66,6 +66,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { const float epsilon = ctx.Attr("epsilon"); const float momentum = ctx.Attr("momentum"); const bool is_test = ctx.Attr("is_test"); + const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); const auto *x = ctx.Input("X"); const auto *mean = ctx.Input("Mean"); @@ -111,6 +112,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { unsigned flags = mkldnn::use_scale_shift; if (is_test) flags |= mkldnn::use_global_stats; + if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; // create mkldnn memory from input x tensor auto src_memory = diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 52b0bf85c0..693bf973c2 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -155,6 +155,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("fuse_with_relu", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddComment(R"DOC( Batch Normalization. diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 02ea2af325..64f48e259a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1993,7 +1993,8 @@ def batch_norm(input, name=None, moving_mean_name=None, moving_variance_name=None, - do_model_average_for_mean_and_var=False): + do_model_average_for_mean_and_var=False, + fuse_with_relu=False): """ **Batch Normalization Layer** @@ -2036,6 +2037,7 @@ def batch_norm(input, moving_mean_name(string, Default None): The name of moving_mean which store the global Mean. moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance. do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not. + fuse_with_relu (bool): if True, this OP performs relu after batch norm. Returns: Variable: A tensor variable which is the result after applying batch normalization on the input. @@ -2121,7 +2123,8 @@ def batch_norm(input, "momentum": momentum, "epsilon": epsilon, "is_test": is_test, - "use_mkldnn": use_mkldnn + "use_mkldnn": use_mkldnn, + "fuse_with_relu": fuse_with_relu }) return helper.append_activation(batch_norm_out) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py index f6097d4b84..18fa546159 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py @@ -52,5 +52,17 @@ class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference): self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5]) +class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference): + def init_kernel_type(self): + self.use_mkldnn = True + self.fuse_with_relu = True + + def test_check_output(self): + place = core.CPUPlace() + data_format = "NCHW" + + self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5]) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 01e5749bdb..a62ee9596d 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -159,6 +159,7 @@ class TestBatchNormOpInference(unittest.TestCase): def setUp(self): self.dtype = np.float32 self.use_mkldnn = False + self.fuse_with_relu = False self.init_kernel_type() def __assert_close(self, tensor, np_array, msg, atol=1e-4): @@ -180,6 +181,8 @@ class TestBatchNormOpInference(unittest.TestCase): scale_shape = [c] x_val = np.random.random_sample(x_shape).astype(dtype) + # generate some negative values to test case with relu fused + x_val = x_val - 0.5 scale_val = np.random.random_sample(scale_shape).astype(np.float32) bias_val = np.random.random_sample(scale_shape).astype(np.float32) @@ -188,6 +191,8 @@ class TestBatchNormOpInference(unittest.TestCase): y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, epsilon, data_layout).astype(dtype) + if self.fuse_with_relu: + y_out = np.maximum(y_out, 0) scope = core.Scope() @@ -233,6 +238,7 @@ class TestBatchNormOpInference(unittest.TestCase): is_test=True, data_layout=data_layout, use_mkldnn=self.use_mkldnn, + fuse_with_relu=self.fuse_with_relu, epsilon=epsilon) batch_norm_op.run(scope, place) @@ -265,6 +271,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference): def setUp(self): self.dtype = np.float16 self.use_mkldnn = False + self.fuse_with_relu = False self.init_kernel_type() def test_check_output(self): @@ -284,6 +291,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference): class TestBatchNormOpTraining(unittest.TestCase): def setUp(self): self.use_mkldnn = False + self.fuse_with_relu = False self.data_formats = ["NCHW", "NHWC"] self.init_kernel_type() @@ -367,7 +375,8 @@ class TestBatchNormOpTraining(unittest.TestCase): "epsilon": epsilon, "is_test": False, "data_layout": data_layout, - "use_mkldnn": self.use_mkldnn + "use_mkldnn": self.use_mkldnn, + "fuse_with_relu": self.fuse_with_relu }) block.create_var(name='y@GRAD', dtype='float32', shape=y.shape) diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py index 0629f2916b..d32c69d148 100644 --- a/python/paddle/fluid/transpiler/inference_transpiler.py +++ b/python/paddle/fluid/transpiler/inference_transpiler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import numpy as np from .. import core from ..framework import Program @@ -20,12 +21,15 @@ from ..executor import global_scope class InferenceTranspiler: ''' - Convert the fluid program to optimized inference program. - - There are several optimizations, only fuse batch normalization is supported now. + Convert the fluid program to optimized inference program. + + There are several optimizations: + + - fuse convolution and batch normalization + - fuse batch normalization and relu (MKLDNN only) Examples: - + .. code-block:: python # As InferenceTranspiler will modify the original program, @@ -54,19 +58,64 @@ class InferenceTranspiler: if not isinstance(scope, core.Scope): raise TypeError("scope should be as Scope type or None") self.fuse_batch_norm(program, place, scope) + self.fuse_relu_mkldnn(program) + + def fuse_relu_mkldnn(self, program): + ''' + Transpile the program by fused relu activation for MKLDNN program. + + Relu activation following batch norm OP can be fused by adding + :math:`fuse_with_relu` attribute to batch norm OP. + + The result of fuse is: + + - before: + + - batch_norm->relu->any_other_op + + - after: + + - batch_norm->any_other_op + + :param program: program to transpile + :type program: Program + ''' + use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False)) + if not use_mkldnn: + return + + self.block = program.block(0) + + i = 0 + while i < len(self.block.ops) - 1: + current_op = self.block.ops[i] + if current_op.type in ['batch_norm']: + next_op = self.block.ops[i + 1] + if next_op.type == 'relu': + # modify bnorm OP to include relu + current_op.set_attr("fuse_with_relu", True) + # remove relu OP + self.block.remove_op(i + 1) + i = i + 1 + + self._remove_unused_var() + # TODO(luotao): use clone() method to flush the program.desc in force, + # since some large program.desc will not be flushed immediately. + # And a better solution will be considered later. + program = program.clone() def fuse_batch_norm(self, program, place, scope): ''' Transpile the program by fused batch normalization. - - The batch normalization followed the convolution or fully connected layer - can be integrated with them. Doing so will give us a forward acceleration, + + The batch normalization followed the convolution or fully connected layer + can be integrated with them. Doing so will give us a forward acceleration, especially in environments like mobile or embedded. - + For input :math:`X`: - - Conv process: :math:`X = input * W + bias` - - Batch norm process: :math:`X' = (X - mean) / std` + - Conv process: :math:`X = input * W + bias` + - Batch norm process: :math:`X' = (X - mean) / std` - Scale Process: :math:`Y = a * X' + b` After fuse into one operation: @@ -76,17 +125,17 @@ class InferenceTranspiler: Y &= (input * W + bias - mean) / std * a + b \\\\ &= input * a * W / std + ((bias - mean) / std * a + b) - The operator transformation is: + The operator transformation is: - before: - conv->batch_norm->any_other_op (bias == 0) - conv->elementwise_add->batch_norm->any_other_op (bias != 0) - - - after: + + - after: - conv->elementwise_add->any_other_op - + The transpile stages are: 1. insert elementwise_add op when bias == 0. @@ -99,20 +148,20 @@ class InferenceTranspiler: program (Program): program to transpile place (Place): inference place scope (Scope): inference Scope - + ''' self.scope = scope self.place = place self.block = program.block(0) - self.input_map = {} # store the input names should be adjusted + self.input_map = {} # store the input names should be adjusted i = 0 - while i < len(self.block.ops): + while i < len(self.block.ops) - 2: current_op = self.block.ops[i] # TODO(luotao1): consider only conv2d now. fc would be delt later. if current_op.type in ['conv2d']: - # TODO(luotao1): consider single chain network now. - # For branch network, we counldn't use block.ops[i + 1] as + # TODO(luotao1): consider single chain network now. + # For branch network, we counldn't use block.ops[i + 1] as # the judgment condition. next_op = self.block.ops[i + 1] # conv2d without bias @@ -137,17 +186,17 @@ class InferenceTranspiler: self._adjust_input() self._remove_unused_var() - # TODO(luotao): use clone() method to flush the program.desc in force, - # since some large program.desc will not be flushed immediately. + # TODO(luotao): use clone() method to flush the program.desc in force, + # since some large program.desc will not be flushed immediately. # And a better solution will be considered later. program = program.clone() # ====================== private transpiler functions ===================== def _insert_bias_op(self, index, current_op, bn_op): ''' - Construct elementwise_add operator for adding bias + Construct elementwise_add operator for adding bias and insert it into program. - + :param index: insert location of bias_op :type index: Int :param current_op: current operator (conv or fc) @@ -175,14 +224,14 @@ class InferenceTranspiler: def _fuse_param(self, current_op, bn_op, bias_op, with_bias): ''' fuse the batch_norm_op' parameters to current_op (conv or fc) - + :param current_op: current operator (conv or fc) :type current_op: Operator :param bn_op: batch norm operator :type bn_op: Operator :param bias_op: elementwise_add operator for adding bias :type bias_op: Operator - :param with_bias: If current operator has bias, with_bias = 1; otherwise 0. + :param with_bias: If current operator has bias, with_bias = 1; otherwise 0. :type with_bias: Int ''' -- GitLab