提交 d40656e5 编写于 作者: G Gabriel de Marmiesse 提交者: François Chollet

Refactoring: Used the function `normalize_data_format` to remove some code. (#10690)

* Used the function `normalize_data_format` to remove some code. Moved the implementation to keras/backend/common.py.

* Changed all conv_utils.normalize_data_format into K.normalize_data_format
上级 d2803c0f
......@@ -99,6 +99,7 @@ EXCLUDE = {
'deserialize',
'get',
'set_image_dim_ordering',
'normalize_data_format',
'image_dim_ordering',
'get_variable_shape',
}
......
......@@ -11,6 +11,7 @@ from .common import set_floatx
from .common import cast_to_floatx
from .common import image_data_format
from .common import set_image_data_format
from .common import normalize_data_format
# Set Keras base dir path given KERAS_HOME env variable, if applicable.
# Otherwise either ~/.keras or /tmp.
......
......@@ -4,7 +4,10 @@ from __future__ import print_function
import cntk as C
import numpy as np
from .common import floatx, epsilon, image_dim_ordering, image_data_format
from .common import floatx
from .common import epsilon
from .common import image_data_format
from .common import normalize_data_format
from collections import defaultdict
from contextlib import contextmanager
import warnings
......@@ -184,10 +187,7 @@ def variable(value, dtype=None, name=None, constraint=None):
def bias_add(x, bias, data_format=None):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
dims = len(x.shape)
if dims > 0 and x.shape[0] == C.InferredDimension:
......@@ -1480,10 +1480,7 @@ def hard_sigmoid(x):
def conv1d(x, kernel, strides=1, padding='valid',
data_format=None, dilation_rate=1):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
if padding == 'causal':
# causal (dilated) convolution:
......@@ -1512,10 +1509,7 @@ def conv1d(x, kernel, strides=1, padding='valid',
def conv2d(x, kernel, strides=(1, 1), padding='valid',
data_format=None, dilation_rate=(1, 1)):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
x = _preprocess_conv2d_input(x, data_format)
kernel = _preprocess_conv2d_kernel(kernel, data_format)
......@@ -1546,10 +1540,7 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid',
def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
padding='valid', data_format=None, dilation_rate=1):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
if isinstance(strides, int):
strides = (strides,)
if isinstance(dilation_rate, int):
......@@ -1599,10 +1590,7 @@ def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
padding='valid', data_format=None, dilation_rate=(1, 1)):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
x = _preprocess_conv2d_input(x, data_format)
depthwise_kernel = _preprocess_conv2d_kernel(depthwise_kernel, data_format)
......@@ -1637,10 +1625,7 @@ def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
def depthwise_conv2d(x, depthwise_kernel, strides=(1, 1), padding='valid',
data_format=None, dilation_rate=(1, 1)):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
x = _preprocess_conv2d_input(x, data_format)
depthwise_kernel = _preprocess_conv2d_kernel(depthwise_kernel, data_format)
......@@ -1668,10 +1653,7 @@ def depthwise_conv2d(x, depthwise_kernel, strides=(1, 1), padding='valid',
def conv3d(x, kernel, strides=(1, 1, 1), padding='valid',
data_format=None, dilation_rate=(1, 1, 1)):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
x = _preprocess_conv3d_input(x, data_format)
kernel = _preprocess_conv3d_kernel(kernel, data_format)
......@@ -1692,10 +1674,7 @@ def conv3d(x, kernel, strides=(1, 1, 1), padding='valid',
def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
padding='valid', data_format=None):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
x = _preprocess_conv3d_input(x, data_format)
kernel = _preprocess_conv3d_kernel(kernel, data_format)
......@@ -1728,10 +1707,7 @@ def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
def pool2d(x, pool_size, strides=(1, 1),
padding='valid', data_format=None,
pool_mode='max'):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
padding = _preprocess_border_mode(padding)
strides = strides
......@@ -1758,10 +1734,7 @@ def pool2d(x, pool_size, strides=(1, 1),
def pool3d(x, pool_size, strides=(1, 1, 1), padding='valid',
data_format=None, pool_mode='max'):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
padding = _preprocess_border_mode(padding)
......@@ -2101,10 +2074,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
assert len(padding) == 2
assert len(padding[0]) == 2
assert len(padding[1]) == 2
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
num_dynamic_axis = _get_dynamic_axis_num(x)
assert len(x.shape) == 4 - (1 if num_dynamic_axis > 0 else 0)
......@@ -2116,10 +2086,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
assert len(padding[0]) == 2
assert len(padding[1]) == 2
assert len(padding[2]) == 2
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
num_dynamic_axis = _get_dynamic_axis_num(x)
assert len(x.shape) == 5 - (1 if num_dynamic_axis > 0 else 0)
......@@ -2224,10 +2191,7 @@ def in_top_k(predictions, targets, k):
def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
padding='valid', data_format=None):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
x = _preprocess_conv2d_input(x, data_format)
kernel = _preprocess_conv2d_kernel(kernel, data_format)
......@@ -2357,10 +2321,7 @@ def _reshape_sequence(x, time_step):
def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
stride = strides[0]
kernel_shape = int_shape(kernel)
......@@ -2389,10 +2350,7 @@ def local_conv2d(inputs,
strides,
output_shape,
data_format=None):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
stride_row, stride_col = strides
output_row, output_col = output_shape
......
......@@ -147,6 +147,37 @@ def set_image_data_format(data_format):
_IMAGE_DATA_FORMAT = str(data_format)
def normalize_data_format(value):
"""Checks that the value correspond to a valid data format.
# Arguments
value: String or None. `'channels_first'` or `'channels_last'`.
# Returns
A string, either `'channels_first'` or `'channels_last'`
# Example
```python
>>> from keras import backend as K
>>> K.normalize_data_format(None)
'channels_first'
>>> K.normalize_data_format('channels_last')
'channels_last'
```
# Raises
ValueError: if `value` or the global `data_format` invalid.
"""
if value is None:
value = image_data_format()
data_format = value.lower()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('The `data_format` argument must be one of '
'"channels_first", "channels_last". Received: ' +
str(value))
return data_format
# Legacy methods
def set_image_dim_ordering(dim_ordering):
......
......@@ -17,8 +17,9 @@ from collections import defaultdict
import numpy as np
import os
from .common import floatx, epsilon
from .common import image_data_format
from .common import floatx
from .common import epsilon
from .common import normalize_data_format
from ..utils.generic_utils import has_arg
# Legacy functions
......@@ -2235,10 +2236,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
assert len(padding) == 2
assert len(padding[0]) == 2
assert len(padding[1]) == 2
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
if data_format == 'channels_first':
pattern = [[0, 0],
......@@ -2279,10 +2277,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
assert len(padding[0]) == 2
assert len(padding[1]) == 2
assert len(padding[2]) == 2
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
if data_format == 'channels_first':
pattern = [
......@@ -3508,10 +3503,7 @@ def conv1d(x, kernel, strides=1, padding='valid',
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
kernel_shape = kernel.get_shape().as_list()
if padding == 'causal':
......@@ -3559,10 +3551,7 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid',
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
x, tf_data_format = _preprocess_conv2d_input(x, data_format)
......@@ -3601,10 +3590,7 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
if isinstance(output_shape, (tuple, list)):
output_shape = tf.stack(output_shape)
......@@ -3653,10 +3639,7 @@ def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
if isinstance(strides, int):
strides = (strides,)
if isinstance(dilation_rate, int):
......@@ -3714,10 +3697,7 @@ def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
x, tf_data_format = _preprocess_conv2d_input(x, data_format)
padding = _preprocess_padding(padding)
......@@ -3756,10 +3736,7 @@ def depthwise_conv2d(x, depthwise_kernel, strides=(1, 1), padding='valid',
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
x, tf_data_format = _preprocess_conv2d_input(x, data_format)
padding = _preprocess_padding(padding)
......@@ -3799,10 +3776,7 @@ def conv3d(x, kernel, strides=(1, 1, 1), padding='valid',
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
x, tf_data_format = _preprocess_conv3d_input(x, data_format)
padding = _preprocess_padding(padding)
......@@ -3839,10 +3813,7 @@ def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
if isinstance(output_shape, (tuple, list)):
output_shape = tf.stack(output_shape)
......@@ -3892,10 +3863,7 @@ def pool2d(x, pool_size, strides=(1, 1),
ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
x, tf_data_format = _preprocess_conv2d_input(x, data_format)
padding = _preprocess_padding(padding)
......@@ -3941,10 +3909,7 @@ def pool3d(x, pool_size, strides=(1, 1, 1), padding='valid',
ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
x, tf_data_format = _preprocess_conv3d_input(x, data_format)
padding = _preprocess_padding(padding)
......@@ -3989,10 +3954,7 @@ def bias_add(x, bias, data_format=None):
the bias should be either a vector or
a tensor with ndim(x) - 1 dimension
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
bias_shape = int_shape(bias)
if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
raise ValueError('Unexpected bias dimensions %d, expect to be 1 or %d dimensions'
......@@ -4321,10 +4283,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
ValueError: If `data_format` is neither
`"channels_last"` nor `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
stride = strides[0]
kernel_shape = int_shape(kernel)
......@@ -4373,10 +4332,7 @@ def local_conv2d(inputs, kernel, kernel_size, strides, output_shape, data_format
ValueError: if `data_format` is neither
`channels_last` or `channels_first`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
data_format = normalize_data_format(data_format)
stride_row, stride_col = strides
output_row, output_col = output_shape
......
......@@ -20,7 +20,9 @@ except ImportError:
from theano.sandbox.softsign import softsign as T_softsign
import numpy as np
from .common import floatx, epsilon, image_data_format
from .common import floatx
from .common import epsilon
from .common import normalize_data_format
from ..utils.generic_utils import has_arg
# Legacy functions
from .common import set_image_dim_ordering, image_dim_ordering
......@@ -1115,10 +1117,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
assert len(padding[1]) == 2
top_pad, bottom_pad = padding[0]
left_pad, right_pad = padding[1]
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
input_shape = x.shape
if data_format == 'channels_first':
......@@ -1151,10 +1150,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
"""Pad the 2nd, 3rd and 4th dimensions of a 5D tensor
with "padding[0]", "padding[1]" and "padding[2]" (resp.) zeros left and right.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
input_shape = x.shape
if data_format == 'channels_first':
......@@ -1933,10 +1929,7 @@ def conv1d(x, kernel, strides=1, padding='valid',
data_format: string, one of "channels_last", "channels_first"
dilation_rate: integer.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ', data_format)
data_format = normalize_data_format(data_format)
kernel_shape = int_shape(kernel)
if padding == 'causal':
......@@ -1990,10 +1983,7 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid',
Whether to use Theano or TensorFlow data format
in inputs/kernels/outputs.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ', data_format)
data_format = normalize_data_format(data_format)
image_shape = _preprocess_conv2d_image_shape(int_shape(x), data_format)
kernel_shape = int_shape(kernel)
......@@ -2033,10 +2023,7 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
ValueError: if using an even kernel size with padding 'same'.
"""
flip_filters = False
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + data_format)
data_format = normalize_data_format(data_format)
if data_format == 'channels_last':
output_shape = (output_shape[0],
......@@ -2089,10 +2076,7 @@ def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
# Raises
ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ', data_format)
data_format = normalize_data_format(data_format)
if isinstance(strides, int):
strides = (strides,)
if isinstance(dilation_rate, int):
......@@ -2163,10 +2147,7 @@ def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
# Raises
ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ', data_format)
data_format = normalize_data_format(data_format)
image_shape = _preprocess_conv2d_image_shape(int_shape(x), data_format)
depthwise_kernel_shape = int_shape(depthwise_kernel)
......@@ -2221,10 +2202,7 @@ def depthwise_conv2d(x, depthwise_kernel, strides=(1, 1), padding='valid',
# Raises
ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ', data_format)
data_format = normalize_data_format(data_format)
image_shape = _preprocess_conv2d_image_shape(int_shape(x), data_format)
depthwise_kernel_shape = int_shape(depthwise_kernel)
......@@ -2261,10 +2239,7 @@ def conv3d(x, kernel, strides=(1, 1, 1),
Whether to use Theano or TensorFlow data format
in inputs/kernels/outputs.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format:', data_format)
data_format = normalize_data_format(data_format)
volume_shape = _preprocess_conv3d_volume_shape(int_shape(x), data_format)
kernel_shape = int_shape(kernel)
......@@ -2304,10 +2279,7 @@ def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
ValueError: if using an even kernel size with padding 'same'.
"""
flip_filters = False
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + data_format)
data_format = normalize_data_format(data_format)
if data_format == 'channels_last':
output_shape = (output_shape[0],
......@@ -2344,10 +2316,7 @@ def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
def pool2d(x, pool_size, strides=(1, 1), padding='valid',
data_format=None, pool_mode='max'):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format:', data_format)
data_format = normalize_data_format(data_format)
assert pool_size[0] >= 1 and pool_size[1] >= 1
......@@ -2389,10 +2358,7 @@ def pool2d(x, pool_size, strides=(1, 1), padding='valid',
def pool3d(x, pool_size, strides=(1, 1, 1), padding='valid',
data_format=None, pool_mode='max'):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format:', data_format)
data_format = normalize_data_format(data_format)
if padding == 'same':
w_pad = pool_size[0] - 2 if pool_size[0] % 2 == 1 else pool_size[0] - 1
......@@ -2436,10 +2402,7 @@ def pool3d(x, pool_size, strides=(1, 1, 1), padding='valid',
def bias_add(x, bias, data_format=None):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
if ndim(bias) != 1 and ndim(bias) != ndim(x) - 1:
raise ValueError('Unexpected bias dimensions %d, '
'expect to be 1 or %d dimensions'
......@@ -2701,10 +2664,7 @@ def foldr(fn, elems, initializer=None, name=None):
def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
stride = strides[0]
kernel_shape = int_shape(kernel)
......@@ -2723,10 +2683,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
def local_conv2d(inputs, kernel, kernel_size, strides, output_shape, data_format=None):
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
data_format = normalize_data_format(data_format)
stride_row, stride_col = strides
output_row, output_col = output_shape
......
......@@ -107,7 +107,7 @@ class _Conv(Layer):
self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, 'dilation_rate')
self.activation = activations.get(activation)
self.use_bias = use_bias
......@@ -1934,7 +1934,7 @@ class UpSampling2D(Layer):
@interfaces.legacy_upsampling2d_support
def __init__(self, size=(2, 2), data_format=None, **kwargs):
super(UpSampling2D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.size = conv_utils.normalize_tuple(size, 2, 'size')
self.input_spec = InputSpec(ndim=4)
......@@ -2002,7 +2002,7 @@ class UpSampling3D(Layer):
@interfaces.legacy_upsampling3d_support
def __init__(self, size=(2, 2, 2), data_format=None, **kwargs):
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.size = conv_utils.normalize_tuple(size, 3, 'size')
self.input_spec = InputSpec(ndim=5)
super(UpSampling3D, self).__init__(**kwargs)
......@@ -2130,7 +2130,7 @@ class ZeroPadding2D(Layer):
data_format=None,
**kwargs):
super(ZeroPadding2D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
if isinstance(padding, int):
self.padding = ((padding, padding), (padding, padding))
elif hasattr(padding, '__len__'):
......@@ -2234,7 +2234,7 @@ class ZeroPadding3D(Layer):
@interfaces.legacy_zeropadding3d_support
def __init__(self, padding=(1, 1, 1), data_format=None, **kwargs):
super(ZeroPadding3D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
if isinstance(padding, int):
self.padding = ((padding, padding), (padding, padding), (padding, padding))
elif hasattr(padding, '__len__'):
......@@ -2413,7 +2413,7 @@ class Cropping2D(Layer):
def __init__(self, cropping=((0, 0), (0, 0)),
data_format=None, **kwargs):
super(Cropping2D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
if isinstance(cropping, int):
self.cropping = ((cropping, cropping), (cropping, cropping))
elif hasattr(cropping, '__len__'):
......@@ -2541,7 +2541,7 @@ class Cropping3D(Layer):
def __init__(self, cropping=((1, 1), (1, 1), (1, 1)),
data_format=None, **kwargs):
super(Cropping3D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
if isinstance(cropping, int):
self.cropping = ((cropping, cropping),
(cropping, cropping),
......
......@@ -563,7 +563,7 @@ class ConvLSTM2DCell(Layer):
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate')
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
......
......@@ -209,12 +209,7 @@ class SpatialDropout2D(Dropout):
@interfaces.legacy_spatialdropoutNd_support
def __init__(self, rate, data_format=None, **kwargs):
super(SpatialDropout2D, self).__init__(rate, **kwargs)
if data_format is None:
data_format = K.image_data_format()
if data_format not in {'channels_last', 'channels_first'}:
raise ValueError('`data_format` must be in '
'{`"channels_last"`, `"channels_first"`}')
self.data_format = data_format
self.data_format = K.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=4)
def _get_noise_shape(self, inputs):
......@@ -262,12 +257,7 @@ class SpatialDropout3D(Dropout):
@interfaces.legacy_spatialdropoutNd_support
def __init__(self, rate, data_format=None, **kwargs):
super(SpatialDropout3D, self).__init__(rate, **kwargs)
if data_format is None:
data_format = K.image_data_format()
if data_format not in {'channels_last', 'channels_first'}:
raise ValueError('`data_format` must be in '
'{`"channels_last"`, `"channels_first"`}')
self.data_format = data_format
self.data_format = K.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=5)
def _get_noise_shape(self, inputs):
......@@ -497,7 +487,7 @@ class Flatten(Layer):
def __init__(self, data_format=None, **kwargs):
super(Flatten, self).__init__(**kwargs)
self.input_spec = InputSpec(min_ndim=3)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
def compute_output_shape(self, input_shape):
if not all(input_shape[1:]):
......
......@@ -101,7 +101,7 @@ class LocallyConnected1D(Layer):
if self.padding != 'valid':
raise ValueError('Invalid border mode for LocallyConnected1D '
'(only "valid" is supported): ' + padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
......@@ -283,7 +283,7 @@ class LocallyConnected2D(Layer):
if self.padding != 'valid':
raise ValueError('Invalid border mode for LocallyConnected2D '
'(only "valid" is supported): ' + padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
......
......@@ -126,7 +126,7 @@ class _Pooling2D(Layer):
self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size')
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=4)
def compute_output_shape(self, input_shape):
......@@ -287,7 +287,7 @@ class _Pooling3D(Layer):
self.pool_size = conv_utils.normalize_tuple(pool_size, 3, 'pool_size')
self.strides = conv_utils.normalize_tuple(strides, 3, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=5)
def compute_output_shape(self, input_shape):
......@@ -488,7 +488,7 @@ class _GlobalPooling2D(Layer):
@interfaces.legacy_global_pooling_support
def __init__(self, data_format=None, **kwargs):
super(_GlobalPooling2D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=4)
def compute_output_shape(self, input_shape):
......@@ -583,7 +583,7 @@ class _GlobalPooling3D(Layer):
@interfaces.legacy_global_pooling_support
def __init__(self, data_format=None, **kwargs):
super(_GlobalPooling3D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=5)
def compute_output_shape(self, input_shape):
......
......@@ -748,7 +748,7 @@ class ConvRecurrent2D(Recurrent):
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate')
self.return_sequences = return_sequences
self.go_backwards = go_backwards
......
......@@ -48,17 +48,6 @@ def normalize_tuple(value, n, name):
return value_tuple
def normalize_data_format(value):
if value is None:
value = K.image_data_format()
data_format = value.lower()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('The `data_format` argument must be one of '
'"channels_first", "channels_last". Received: ' +
str(value))
return data_format
def normalize_padding(value):
padding = value.lower()
allowed = {'valid', 'same', 'causal'}
......
import pytest
import numpy as np
from keras.utils import conv_utils
from keras import backend as K
def test_normalize_tuple():
......@@ -17,7 +18,7 @@ def test_normalize_tuple():
def test_invalid_data_format():
with pytest.raises(ValueError):
conv_utils.normalize_data_format('channels_middle')
K.normalize_data_format('channels_middle')
def test_invalid_padding():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册