Commit 84aa7b5f authored by Wang, Zhiming, committed by François Chollet

Non-training Batch Norm operator has bad performance because it falls back to TensorFlow's non-fused batch norm API (#10207)

* When using TensorFlow as the backend, route batch norm through fused batch norm whenever possible, which has better performance.

Fixes issue: http://github.com/keras-team/keras/issues/10058

* In the TensorFlow backend, have batch norm call FusedBatchNorm only for the NHWC format, and only when gamma and beta are not None.

Test result:
test env: TensorFlow (commit a543d9471047ca3f6881c87105fcbe2cdff9207d, Date: Thu May 10 17:43:30 2018, local build), Python 3.4, CentOS 7.4
test cases:
  "pytest ./tests/keras/layers/normalization_test.py"  <all passed>
  "pytest ./tests"      <same results as without this commit's modification to BN>

* Fix code style.

* 1. Add an axis parameter to the backends' batch_normalization functions.
2. Refine the batch_normalization function in the TensorFlow backend so that it calls fused batch norm whenever possible.

Thanks for the comments from fchollet.

* Trigger

* 1. Add a default value of -1 for the axis parameter of the backends' batch_normalization function.
2. Fix some code style.
Thanks for the comments from fchollet.
Parent e6d21795
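For context, a minimal sketch (not part of this commit) of how the new axis argument is passed at the backend level; the shapes and values below are illustrative assumptions:

    from keras import backend as K
    import numpy as np

    # Assumed NHWC activations: batch of 2, 32x32 spatial, 64 channels.
    x = K.variable(np.random.rand(2, 32, 32, 64))
    mean = K.variable(np.zeros(64))
    var = K.variable(np.ones(64))
    beta = K.variable(np.zeros(64))
    gamma = K.variable(np.ones(64))

    # axis=3 marks the channels axis of NHWC input; with this change the
    # TensorFlow backend can route this call to tf.nn.fused_batch_norm.
    y = K.batch_normalization(x, mean, var, beta, gamma, axis=3, epsilon=1e-3)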
@@ -1063,7 +1063,7 @@ def _moments(x, axes=None, shift=None, keep_dims=False):
     return mean, variance
 
 
-def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
+def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
     # The mean / var / beta / gamma may be processed by broadcast
     # so it may have an extra batch axis with 1, it is not needed
     # in cntk, need to remove those dummy axis.
......
@@ -1850,7 +1850,7 @@ def normalize_batch_in_training(x, gamma, beta,
             epsilon=epsilon)
 
 
-def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
+def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
     """Applies batch normalization on x given mean, var, beta and gamma.
 
     I.e. returns:
@@ -1862,11 +1862,39 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
         var: Variance of batch.
         beta: Tensor with which to center the input.
         gamma: Tensor by which to scale the input.
+        axis: Integer, the axis that should be normalized.
+            (typically the features axis).
         epsilon: Fuzz factor.
 
     # Returns
         A tensor.
     """
+    if ndim(x) == 4:
+        # The CPU implementation of FusedBatchNorm only support NHWC
+        if axis == 1:
+            tf_data_format = 'NCHW'
+        elif axis == 3:
+            tf_data_format = 'NHWC'
+        else:
+            tf_data_format = None
+
+        if tf_data_format == 'NHWC' or tf_data_format == 'NCHW' and _has_nchw_support():
+            if beta is None:
+                beta = zeros_like(mean)
+            if gamma is None:
+                gamma = ones_like(mean)
+            y, _, _ = tf.nn.fused_batch_norm(
+                x,
+                gamma,
+                beta,
+                epsilon=epsilon,
+                mean=mean,
+                variance=var,
+                data_format=tf_data_format,
+                is_training=False
+            )
+            return y
+    # default
     return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
......
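A note on the dispatch condition in the TensorFlow hunk above: Python evaluates `and` before `or`, so the fused path is taken unconditionally for NHWC, and for NCHW only when the backend reports NCHW support (typically when a GPU is available). A small sketch of the equivalent, fully parenthesized check; only _has_nchw_support and tf_data_format come from the diff, the rest is illustrative:

    # Equivalent, explicitly parenthesized form of the condition above.
    use_fused = (tf_data_format == 'NHWC') or (
        (tf_data_format == 'NCHW') and _has_nchw_support())
    if use_fused:
        # beta/gamma are replaced by zeros/ones when None, because
        # tf.nn.fused_batch_norm always expects offset and scale tensors.
        pass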
@@ -731,7 +731,7 @@ def normalize_batch_in_training(x, gamma, beta,
     return normed, mean, T.inv(stdinv ** 2)
 
 
-def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
+def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
     """Apply batch normalization on x given mean, var, beta and gamma.
     """
     # TODO remove this if statement when Theano without
......
@@ -161,6 +161,7 @@ class BatchNormalization(Layer):
                     broadcast_moving_variance,
                     broadcast_beta,
                     broadcast_gamma,
+                    axis=self.axis,
                     epsilon=self.epsilon)
             else:
                 return K.batch_normalization(
@@ -169,6 +170,7 @@ class BatchNormalization(Layer):
                     self.moving_variance,
                     self.beta,
                     self.gamma,
+                    axis=self.axis,
                     epsilon=self.epsilon)
 
         # If the learning phase is *static* and set to inference:
......
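At the layer level, self.axis now flows through to the backend call shown above. A minimal usage sketch under assumed shapes; the model below is illustrative and not part of this commit:

    from keras.models import Sequential
    from keras.layers import BatchNormalization

    model = Sequential()
    # channels_first input (N, C, H, W): axis=1 is the features axis, so the
    # TensorFlow backend may use the NCHW fused kernel when supported.
    model.add(BatchNormalization(axis=1, input_shape=(64, 32, 32)))
    model.compile(optimizer='sgd', loss='mse')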