diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index c805d15fbbf99381ce84731c12ca2be8b85ecd81..c951de5dd5a7b66ce03c705e9bdcbe3f5c3e565d 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -91,7 +91,6 @@ void ReduceOpHandle::RunImpl() { if (paddle::platform::is_cpu_place(pre_place)) { ReduceLoDTensor func(lod_tensors, trg); VisitDataType(ToDataType(lod_tensors[0].type()), func); - } else if (paddle::platform::is_gpu_place(pre_place)) { #ifdef PADDLE_WITH_CUDA auto out_p = out_var_handles[0]->place_; diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index d7a8f918ed2b377be867f9b568434f9a96f7deec..63d371310d2a26a1460e527fc51923dfd6e0b8bc 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_md = platform::MKLDNNMemDesc( dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); - auto src_memory = mkldnn::memory({src_md, mkldnn_engine}, - reinterpret_cast(input_data)); - auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine}, - reinterpret_cast(filter_data)); + auto src_memory = + mkldnn::memory({src_md, mkldnn_engine}, + reinterpret_cast(const_cast(input_data))); + auto weights_memory = + mkldnn::memory({weights_md, mkldnn_engine}, + reinterpret_cast(const_cast(filter_data))); auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data); std::shared_ptr conv_pd = @@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); // create memory - auto diff_dst_memory = - mkldnn::memory({diff_weights_md, mkldnn_engine}, - reinterpret_cast(output_grad_data)); + auto diff_dst_memory = mkldnn::memory( + {diff_weights_md, 
mkldnn_engine}, // NOTE(review): diff_dst_memory wraps output_grad_data but is described by diff_weights_md - confirm a dst-shaped descriptor is not intended + reinterpret_cast(const_cast(output_grad_data))); // Retrieve conv_pd from device context auto conv_pd = std::static_pointer_cast( @@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto diff_weights_memory = mkldnn::memory({diff_weights_md, mkldnn_engine}, reinterpret_cast(filter_grad_data)); - auto src_memory = mkldnn::memory({src_md, mkldnn_engine}, - reinterpret_cast(input_data)); + auto src_memory = + mkldnn::memory({src_md, mkldnn_engine}, + reinterpret_cast(const_cast(input_data))); // create backward conv primitive for weights auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights( @@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { strides, paddings, *conv_pd, mkldnn_engine); // create memory - auto diff_src_memory = - mkldnn::memory({diff_src_md, mkldnn_engine}, - reinterpret_cast(input_grad_data)); - auto weights_memory = mkldnn::memory( - {weights_md, mkldnn_engine}, reinterpret_cast(filter_data)); + auto diff_src_memory = mkldnn::memory( + {diff_src_md, mkldnn_engine}, + reinterpret_cast(const_cast(input_grad_data))); + auto weights_memory = + mkldnn::memory({weights_md, mkldnn_engine}, + reinterpret_cast(const_cast(filter_data))); // create backward conv primitive for data auto conv_bwd_data_prim = mkldnn::convolution_backward_data( diff --git a/paddle/fluid/operators/softmax_mkldnn_op.cc b/paddle/fluid/operators/softmax_mkldnn_op.cc index dc2f1763446b2aaf72b20c72e8e37ec920abd120..d00bd1447e6114b6000b65799abb566a2a510127 100644 --- a/paddle/fluid/operators/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/softmax_mkldnn_op.cc @@ -73,6 +73,15 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { softmax_dst_memory); std::vector pipeline{softmax}; stream(stream::kind::eager).submit(pipeline).wait(); + + 
const bool is_test = ctx.Attr("is_test"); + if (!is_test) { + T threshold = exp(-64); // floor applied to every output element below when is_test is false + for (size_t i = 0; i < dst_tz[0] * dst_tz[1]; 
++i) { + output_data[i] = + output_data[i] < threshold ? threshold : output_data[i]; + } + } } }; diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 6bdefc0f23910c90f3878d8f2634ca6e03c6f736..e1f286f9ba42ff22fffbfc012832dd751a37c1d0 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -97,6 +97,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("is_test", + "Disable epsilon adding to softmax results. Used by MKLDNN.") + .SetDefault(false); AddComment(R"DOC( Softmax Operator. diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index f757411b853bacb9e03fc42fa2ef6593c3cde00f..e9ca0d45f98bd27692a15060310d4e8cd1e8b181 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -37,6 +37,7 @@ from distribute_transpiler import DistributeTranspiler from distribute_transpiler_simple import SimpleDistributeTranspiler from concurrency import (Go, make_channel, channel_send, channel_recv, channel_close, Select) +from inference_transpiler import InferenceTranspiler import clip from memory_optimization_transpiler import memory_optimize, release_memory import profiler @@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [ 'clip', 'SimpleDistributeTranspiler', 'DistributeTranspiler', + 'InferenceTranspiler', 'memory_optimize', 'release_memory', 'profiler', diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py new file mode 100644 index 0000000000000000000000000000000000000000..39b01610f96018e1775405a30147e77006cecc16 --- /dev/null +++ b/python/paddle/fluid/inference_transpiler.py @@ -0,0 +1,240 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from framework import Program +from executor import global_scope +from . import core + + +class InferenceTranspiler: + def transpile(self, program, place, scope=None): + ''' + Transpile the program. Only batch normalization fusion is supported for now. + + :param program: program to transpile + :type program: Program + :param place: inference place + :type place: CPUPlace or CUDAPlace + :param scope: inference scope + :type scope: Scope or None + ''' + if not isinstance(program, Program): + raise TypeError("program should be as Program type") + if not isinstance(place, core.CPUPlace) and not isinstance( + place, core.CUDAPlace): + raise TypeError("place should be as CPUPlace/CUDAPlace type") + if scope is None: + scope = global_scope() + if not isinstance(scope, core.Scope): + raise TypeError("scope should be as Scope type or None") + self.fuse_batch_norm(program, place, scope) + + def fuse_batch_norm(self, program, place, scope): + ''' + Transpile the program by fusing batch normalization. + + A batch normalization that follows a convolution or fully connected layer + can be folded into it. Doing so speeds up the forward pass, + especially in environments like mobile or embedded. 
+ + For input X: + - Conv process: X = input * W + bias + - Batch norm process: X' = (X - mean) / std + - Scale Process: Y = a * X' + b + + After fusing into one operation: + + Y = (input * W + bias - mean) / std * a + b + = input * a * W / std + ((bias - mean) / std * a + b) + + The operator transformation is: + - before: + - conv->batch_norm->any_other_op (bias == 0) + - conv->elementwise_add->batch_norm->any_other_op (bias != 0) + - after: + - conv->elementwise_add->any_other_op + + The transpile stages are: + 1. insert elementwise_add op when bias == 0. + 2. fuse the batch_norm's parameters to conv and elementwise_add operators. + 3. remove batch_norm ops which are not used in any other ops. + 4. adjust the input of any_other_op to be the output of elementwise_add operator. + 5. remove unused variables. + + :param program: program to transpile + :type program: Program + :param place: inference place + :type place: Place + :param scope: inference scope + :type scope: Scope + ''' + self.scope = scope + self.place = place + self.block = program.block(0) + self.input_map = {} # maps a removed batch_norm output name to the elementwise_add output name that replaces it + + i = 0 + while i < len(self.block.ops): + current_op = self.block.ops[i] + # TODO(luotao1): consider only conv2d now. fc would be dealt later. + if current_op.type in ['conv2d']: + # TODO(luotao1): consider single chain network now. + # For branch network, we couldn't use block.ops[i + 1] as + # the judgment condition. 
+ next_op = self.block.ops[i + 1] # NOTE(review): i + 1 (and i + 2 below) are not bounds-checked; presumably IndexError if conv2d is the last op - confirm + # conv2d without bias + if (next_op.type == 'batch_norm'): + # insert bias op + bias_op = self._insert_bias_op(i + 1, current_op, next_op) + # fuse batch_norm + self._fuse_param(current_op, next_op, bias_op, 0) + # remove batch_norm_op + self.block.remove_op(i + 2) + i = i + 1 + # conv2d with bias, the next_op.type is elementwise_add + elif (next_op.type == 'elementwise_add'): + next_next_op = self.block.ops[i + 2] + if (next_next_op.type == 'batch_norm'): + # fuse batch_norm + self._fuse_param(current_op, next_next_op, next_op, 1) + # remove batch_norm_op + self.block.remove_op(i + 2) + i = i + 1 + i = i + 1 + + self._adjust_input() + self._remove_unused_var() + # TODO(luotao): call clone() to force program.desc to be flushed, + # since some large program.desc will not be flushed immediately. + # And a better solution will be considered later. + program = program.clone() # NOTE(review): this only rebinds the local name `program` + + # ====================== private transpiler functions ===================== + def _insert_bias_op(self, index, current_op, bn_op): + ''' + Construct elementwise_add operator for adding bias + and insert it into program. 
+ + :param index: insert location of bias_op + :type index: Int + :param current_op: current operator (conv or fc) + :type current_op: Operator + :param bn_op: batch norm operator + :type bn_op: Operator + :return: bias_op + :rtype: Operator + ''' + # The input of bias_op is current_op's output and Bias of bn_op + # The output of bias_op is bn_op's output + x_var = self.block.var(current_op.output("Output")[0]) + y_var = self.block.var(bn_op.input("Bias")[0]) + out_var = self.block.var(bn_op.output("Y")[0]) + + bias_op = self.block.insert_op( + index, + type="elementwise_add", + inputs={"X": x_var, + "Y": y_var}, + outputs={"Out": out_var}, + attrs={"axis": 1}) # dim_start=1 + return bias_op + + def _fuse_param(self, current_op, bn_op, bias_op, with_bias): + ''' + fuse the batch_norm_op's parameters into current_op (conv or fc) + + :param current_op: current operator (conv or fc) + :type current_op: Operator + :param bn_op: batch norm operator + :type bn_op: Operator + :param bias_op: elementwise_add operator for adding bias + :type bias_op: Operator + :param with_bias: If current operator has bias, with_bias = 1; otherwise 0. + :type with_bias: Int + ''' + + def _update_param(op, old_param_name, new_param): + # To keep the original variables unchanged, + # create new variables in scope to store the new parameters. 
+ old_param_name = old_param_name[0] + old_var = self.block.vars[old_param_name] + new_param_name = old_param_name + '_fuse_bn' + new_var = self.block.create_parameter( + name=new_param_name.encode('ascii'), + type=old_var.type, + dtype=old_var.dtype, + shape=old_var.shape) + op.rename_input(old_param_name, new_param_name) + self.scope.var(new_param_name) + + tensor = self.scope.find_var(new_param_name).get_tensor() + tensor.set(np.array(new_param), self.place) + + def _load_param(param_name): + return np.array(self.scope.find_var(param_name[0]).get_tensor()) + + bias_bn = _load_param(bn_op.input("Bias")) #Bias + scale_bn = _load_param(bn_op.input("Scale")) #Scale + mean_bn = _load_param(bn_op.input("Mean")) #Mean + var_bn = _load_param(bn_op.input("Variance")) #Variance + + # TODO(luotao1): consider only conv2d now. fc would be dealt later. + current_param = _load_param(current_op.input("Filter")) + std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5))) # NOTE(review): epsilon is hard-coded as 1e-5 - confirm it matches bn_op's epsilon attribute + tmp = np.float32(np.divide(scale_bn, std_bn)) + + # fold batch_norm into the bias: (bias - mean) * scale / std + bias_bn + if with_bias: + bias = _load_param(bias_op.input("Y")) + else: + bias = np.zeros(bias_bn.shape) + bias = np.float32( + np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn)) + + # re-compute weight of conv2d: scale each first-axis slice of Filter by scale / std + tmp = tmp.reshape(tmp.shape[0], -1) + dst_param = current_param.reshape((tmp.shape[0], -1)) + dst_param = np.float32(np.multiply(dst_param, tmp)) + dst_param = dst_param.reshape(current_param.shape) + + # update parameters + _update_param(current_op, current_op.input("Filter"), dst_param) + _update_param(bias_op, bias_op.input("Y"), bias) + + # collect the renamed input + self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0] + + def _adjust_input(self): + for i in range(len(self.block.ops)): + current_op = self.block.ops[i] + for 
input_arg in current_op.input_arg_names: + if input_arg in self.input_map: + current_op.rename_input(input_arg, + self.input_map[input_arg]) + + def _remove_unused_var(self): + 
''' + remove unused variables in program + ''' + args = [] + for i in range(len(self.block.ops)): + current_op = self.block.ops[i] + args += current_op.input_arg_names + args += current_op.output_arg_names + args = list(set(args)) # deduplicate the input and output arguments + + for var in self.block.vars.keys(): + if var not in args: + self.block.remove_var(var) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bba8b64bd88c3edc6eda110dde38c0ced50439f6..2993cb973456836ab124cdb267dbb92c45fcecbc 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -88,6 +88,7 @@ def fc(input, bias_attr=None, use_mkldnn=False, act=None, + is_test=False, name=None): """ **Fully Connected Layer** @@ -134,6 +135,7 @@ def fc(input, bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias of this layer. If it is set to None, no bias will be added to the output units. act (str, default None): Activation to be applied to the output of this layer. + is_test (bool, default False): A flag indicating whether execution is in test phase. use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn library is installed. Default: False name (str, default None): The name of this layer. 
@@ -177,8 +179,11 @@ def fc(input, inputs={"Input": input, "W": w}, outputs={"Out": tmp}, - attrs={"use_mkldnn": use_mkldnn, - "bias_attr": bias_attr}) + attrs={ + "use_mkldnn": use_mkldnn, + "is_test": is_test, + "bias_attr": bias_attr + }) return helper.append_activation(tmp) else: for input_var, param_attr in helper.iter_inputs_and_params(): diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index e8bb082be196b6342b1719235f1264bbe3d776ac..0027b651e88b68950e77e03399b3987aa0120192 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -22,10 +22,17 @@ import sys import numpy import unittest import os +import numpy as np def resnet_cifar10(input, depth=32): - def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'): + def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + act='relu', + bias_attr=False): tmp = fluid.layers.conv2d( input=input, filter_size=filter_size, @@ -33,7 +40,7 @@ def resnet_cifar10(input, depth=32): stride=stride, padding=padding, act=None, - bias_attr=False) + bias_attr=bias_attr) return fluid.layers.batch_norm(input=tmp, act=act) def shortcut(input, ch_in, ch_out, stride): @@ -44,7 +51,7 @@ def resnet_cifar10(input, depth=32): def basicblock(input, ch_in, ch_out, stride): tmp = conv_bn_layer(input, ch_out, 3, stride, 1) - tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None) + tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True) short = shortcut(input, ch_in, ch_out, stride) return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') @@ -219,11 +226,26 @@ def infer(use_cuda, save_dirname=None): batch_size = 1 tensor_img = numpy.random.rand(batch_size, 3, 32, 32).astype("float32") + # Use inference_transpiler to speed up inference + inference_transpiler_program = inference_program.clone() + t = fluid.InferenceTranspiler() + 
t.transpile(inference_transpiler_program, place) + # Construct feed as a dictionary of {feed_target_name: feed_target_data} # and results will contain a list of data corresponding to fetch_targets. results = exe.run(inference_program, feed={feed_target_names[0]: tensor_img}, fetch_list=fetch_targets) + + transpiler_results = exe.run(inference_transpiler_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets) + + assert len(results[0]) == len(transpiler_results[0]) + for i in range(len(results[0])): # transpiled outputs must match the original outputs to 6 decimals + np.testing.assert_almost_equal( + results[0][i], transpiler_results[0][i], decimal=6) + print("infer results: ", results[0])