Commit 9a15c923 authored by pzelazko-intel, committed by Tao Luo

bnorm+relu fuse for mkldnn (inference) (#11434)

* bnorm+relu fuse for mkldnn

* separate fuse_relu function

* bug fix

* proper while range in inference_transpiler

* description fix

* review fix

* review fix

* unit test for fwd batch norm+relu MKLDNN fuse
Parent ce5f1e05
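For orientation, here is a minimal sketch of how this fuse is exercised at inference time. It assumes an already-built inference `Program` (called `inference_program` below purely for illustration) and MKLDNN enabled through the `FLAGS_use_mkldnn` environment variable, which the transpiler checks before doing anything:

```python
import os
import paddle.fluid as fluid

# fuse_relu_mkldnn() is a no-op unless this flag is set.
os.environ["FLAGS_use_mkldnn"] = "1"

place = fluid.CPUPlace()

# `inference_program` is assumed to be a fluid Program whose batch_norm
# ops are each followed by a relu op.
t = fluid.InferenceTranspiler()
t.transpile(inference_program, place)  # rewrites batch_norm -> relu into one fused batch_norm
```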
......@@ -122,5 +122,9 @@ def parse_args():
type=str,
default="",
help='Directory that contains all the training recordio files.')
parser.add_argument(
'--use_inference_transpiler',
action='store_true',
help='If set, uses inference transpiler to optimize the program.')
args = parser.parse_args()
return args
......@@ -131,6 +131,11 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
exe = fluid.Executor(place)
exe.run(startup_prog)
# Use inference_transpiler to speed up inference
if args.use_inference_transpiler:
t = fluid.InferenceTranspiler()
t.transpile(infer_prog, place)
if not args.use_reader_op:
feed_var_list = [
var for var in train_prog.global_block().vars.itervalues()
......
......@@ -66,6 +66,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
const auto *x = ctx.Input<Tensor>("X");
const auto *mean = ctx.Input<Tensor>("Mean");
......@@ -111,6 +112,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats;
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor
auto src_memory =
......
......@@ -155,6 +155,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_with_relu",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Batch Normalization.
......
......@@ -1993,7 +1993,8 @@ def batch_norm(input,
name=None,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=False):
do_model_average_for_mean_and_var=False,
fuse_with_relu=False):
"""
**Batch Normalization Layer**
......@@ -2036,6 +2037,7 @@ def batch_norm(input,
moving_mean_name(string, Default None): The name of moving_mean which store the global Mean.
moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance.
do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not.
fuse_with_relu (bool): if True, this OP performs relu after batch norm.
Returns:
Variable: A tensor variable which is the result after applying batch normalization on the input.
......@@ -2121,7 +2123,8 @@ def batch_norm(input,
"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test,
"use_mkldnn": use_mkldnn
"use_mkldnn": use_mkldnn,
"fuse_with_relu": fuse_with_relu
})
return helper.append_activation(batch_norm_out)
......
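As a usage sketch of the new `fuse_with_relu` argument at the layer level (the network below, including the input name, shape, and filter size, is hypothetical and only meant to show the call):

```python
import paddle.fluid as fluid

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    # Hypothetical NCHW input batch.
    image = fluid.layers.data(name='image', shape=[3, 224, 224], dtype='float32')
    conv = fluid.layers.conv2d(input=image, num_filters=64, filter_size=3)
    # With use_mkldnn=True and fuse_with_relu=True the MKLDNN batch_norm
    # kernel applies relu itself, so no separate relu op is appended here.
    bn = fluid.layers.batch_norm(input=conv,
                                 is_test=True,
                                 use_mkldnn=True,
                                 fuse_with_relu=True)
```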
......@@ -52,5 +52,17 @@ class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference):
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
def init_kernel_type(self):
self.use_mkldnn = True
self.fuse_with_relu = True
def test_check_output(self):
place = core.CPUPlace()
data_format = "NCHW"
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
if __name__ == '__main__':
unittest.main()
......@@ -159,6 +159,7 @@ class TestBatchNormOpInference(unittest.TestCase):
def setUp(self):
self.dtype = np.float32
self.use_mkldnn = False
self.fuse_with_relu = False
self.init_kernel_type()
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
......@@ -180,6 +181,8 @@ class TestBatchNormOpInference(unittest.TestCase):
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(dtype)
# generate some negative values to test case with relu fused
x_val = x_val - 0.5
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
......@@ -188,6 +191,8 @@ class TestBatchNormOpInference(unittest.TestCase):
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
epsilon, data_layout).astype(dtype)
if self.fuse_with_relu:
y_out = np.maximum(y_out, 0)
scope = core.Scope()
......@@ -233,6 +238,7 @@ class TestBatchNormOpInference(unittest.TestCase):
is_test=True,
data_layout=data_layout,
use_mkldnn=self.use_mkldnn,
fuse_with_relu=self.fuse_with_relu,
epsilon=epsilon)
batch_norm_op.run(scope, place)
......@@ -265,6 +271,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
def setUp(self):
self.dtype = np.float16
self.use_mkldnn = False
self.fuse_with_relu = False
self.init_kernel_type()
def test_check_output(self):
......@@ -284,6 +291,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
class TestBatchNormOpTraining(unittest.TestCase):
def setUp(self):
self.use_mkldnn = False
self.fuse_with_relu = False
self.data_formats = ["NCHW", "NHWC"]
self.init_kernel_type()
......@@ -367,7 +375,8 @@ class TestBatchNormOpTraining(unittest.TestCase):
"epsilon": epsilon,
"is_test": False,
"data_layout": data_layout,
"use_mkldnn": self.use_mkldnn
"use_mkldnn": self.use_mkldnn,
"fuse_with_relu": self.fuse_with_relu
})
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
......
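The expected output the test checks against boils down to inference-mode batch normalization followed by an element-wise relu. A minimal NumPy sketch of that reference computation (variable names and shapes here are illustrative, not taken from the test file):

```python
import numpy as np

def fused_bn_relu_reference(x, scale, bias, mean, variance, epsilon=1e-5):
    # Inference-mode batch norm over the channel axis (NCHW layout assumed):
    # y = scale * (x - mean) / sqrt(var + eps) + bias, followed by relu.
    shape = [1, -1, 1, 1]  # broadcast per-channel statistics
    normalized = (x - mean.reshape(shape)) / np.sqrt(variance.reshape(shape) + epsilon)
    y = scale.reshape(shape) * normalized + bias.reshape(shape)
    return np.maximum(y, 0.0)

x = np.random.random_sample((2, 3, 4, 5)).astype(np.float32) - 0.5  # include negatives
scale = np.random.random_sample((3,)).astype(np.float32)
bias = np.random.random_sample((3,)).astype(np.float32)
mean = np.zeros((3,), dtype=np.float32)
variance = np.ones((3,), dtype=np.float32)

y = fused_bn_relu_reference(x, scale, bias, mean, variance)
assert (y >= 0).all()  # relu guarantees non-negative output
```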
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from .. import core
from ..framework import Program
......@@ -22,7 +23,10 @@ class InferenceTranspiler:
'''
Convert the fluid program to optimized inference program.
There are several optimizations, only fuse batch normalization is supported now.
There are several optimizations:
- fuse convolution and batch normalization
- fuse batch normalization and relu (MKLDNN only)
Examples:
......@@ -54,6 +58,51 @@ class InferenceTranspiler:
if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None")
self.fuse_batch_norm(program, place, scope)
self.fuse_relu_mkldnn(program)
def fuse_relu_mkldnn(self, program):
'''
Transpile the program by fusing relu activation into batch norm for MKLDNN inference.
A relu activation immediately following a batch_norm OP can be fused by adding the
:math:`fuse_with_relu` attribute to that batch_norm OP.
The result of fuse is:
- before:
- batch_norm->relu->any_other_op
- after:
- batch_norm->any_other_op
:param program: program to transpile
:type program: Program
'''
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
if not use_mkldnn:
return
self.block = program.block(0)
i = 0
while i < len(self.block.ops) - 1:
current_op = self.block.ops[i]
if current_op.type in ['batch_norm']:
next_op = self.block.ops[i + 1]
if next_op.type == 'relu':
# modify bnorm OP to include relu
current_op.set_attr("fuse_with_relu", True)
# remove relu OP
self.block.remove_op(i + 1)
i = i + 1
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def fuse_batch_norm(self, program, place, scope):
'''
......@@ -107,7 +156,7 @@ class InferenceTranspiler:
self.input_map = {} # store the input names should be adjusted
i = 0
while i < len(self.block.ops):
while i < len(self.block.ops) - 2:
current_op = self.block.ops[i]
# TODO(luotao1): consider only conv2d now. fc would be dealt with later.
if current_op.type in ['conv2d']:
......
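A rough sketch of checking the effect of this pass on a program, assuming `FLAGS_use_mkldnn` is set and `program` (an illustrative name) holds a batch_norm op directly followed by a relu op:

```python
import os
import paddle.fluid as fluid

os.environ["FLAGS_use_mkldnn"] = "1"

# `program` is assumed to contain a batch_norm op followed by a relu op.
t = fluid.InferenceTranspiler()
t.fuse_relu_mkldnn(program)

ops = program.block(0).ops
# After the pass, the fused relu op should be gone and the batch_norm op
# should carry the fuse_with_relu attribute instead.
assert all(op.type != 'relu' for op in ops)
assert any(op.type == 'batch_norm' and op.attr('fuse_with_relu') for op in ops)
```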