提交 cb96315a 编写于 作者: G Gabriel de Marmiesse 提交者: François Chollet

Added batch_normalization in the numpy backend. (#11556)

* Finished the functions.

* Started doing the test function.

* Added the batch_normalization operation to the numpy backend.

* forgot an argument
上级 1eac8612
......@@ -528,6 +528,10 @@ def print_tensor(x, message=''):
return x
def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=0.001):
    """Apply batch normalization to `x` (numpy backend).

    Computes `gamma * (x - mean) / sqrt(var + epsilon) + beta`.

    # Arguments
        x: Input array.
        mean: Mean of batch; must broadcast against `x`.
        var: Variance of batch; must broadcast against `x`.
        beta: Offset; must broadcast against `x`.
        gamma: Scale; must broadcast against `x`.
        axis: Unused here — callers pass `mean`/`var`/`beta`/`gamma`
            already shaped to broadcast along the intended axis.
            Kept for signature parity with the other backends.
        epsilon: Fuzz factor added to the variance for numerical stability.

    # Returns
        The normalized array.
    """
    # np.sqrt (not bare `sqrt`) — consistent with the other functions in
    # this backend (e.g. `dot` uses np.dot) and avoids a NameError.
    return ((x - mean) / np.sqrt(var + epsilon)) * gamma + beta
def dot(x, y):
    """Dot product of two arrays (numpy backend).

    # Arguments
        x: First operand.
        y: Second operand.

    # Returns
        The result of `np.dot(x, y)`.
    """
    product = np.dot(x, y)
    return product
......
......@@ -2332,6 +2332,8 @@ def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
# Returns
A tensor.
{{np_implementation}}
"""
if ndim(x) == 4:
# The CPU implementation of FusedBatchNorm only support NHWC
......
......@@ -1641,6 +1641,34 @@ class TestBackend(object):
with pytest.raises(ValueError):
K.bias_add(x, b, data_format='channels_middle')
@pytest.mark.skipif(K.backend() == 'theano',
                    reason='Theano behaves differently '
                           'because of the broadcast.')
@pytest.mark.parametrize('axis', [1, -1])
@pytest.mark.parametrize('x_shape', [(3, 2, 4, 5), (3, 2, 4)])
def test_batch_normalization(self, axis, x_shape):
    """All backends (incl. numpy) agree on batch_normalization output."""
    # Parameters are broadcast against x: shape is 1 everywhere
    # except along the normalized axis.
    param_shape = [1] * len(x_shape)
    param_shape[axis] = x_shape[axis]
    param_shape = tuple(param_shape)

    x_np = np.random.random(x_shape)
    mean_np = np.random.random(param_shape)
    var_np = np.random.random(param_shape)
    beta_np = np.random.random(param_shape)
    gamma_np = np.random.random(param_shape)

    tensors = []
    arrays = []
    for backend in WITH_NP:
        out = backend.batch_normalization(
            backend.variable(x_np),
            backend.variable(mean_np),
            backend.variable(var_np),
            backend.variable(beta_np),
            backend.variable(gamma_np),
            axis=axis)
        tensors.append(out)
        arrays.append(backend.eval(out))

    # Every backend must produce the same values and keras shapes.
    assert_list_pairwise(arrays)
    assert_list_keras_shape(tensors, arrays)
@pytest.mark.skipif(K.backend() != 'theano',
reason='Specific to Theano.')
@pytest.mark.parametrize('x_shape', [(1, 4, 2, 3), (1, 2, 3, 4)])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册