提交 9c689775 编写于 作者: F Francois Chollet

Use NCW/NWC for conv1d data format in TF backend.

上级 895dba6c
......@@ -3325,12 +3325,12 @@ def _preprocess_conv1d_input(x, data_format):
"""
if dtype(x) == 'float64':
x = tf.cast(x, 'float32')
tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
tf_data_format = 'NWC' # to pass TF Conv2dNative operations
if data_format == 'channels_first':
if not _has_nchw_support():
x = tf.transpose(x, (0, 2, 1)) # NCW -> NWC
else:
tf_data_format = 'NCHW'
tf_data_format = 'NCW'
return x, tf_data_format
......@@ -3441,7 +3441,7 @@ def conv1d(x, kernel, strides=1, padding='valid',
padding=padding,
data_format=tf_data_format)
if data_format == 'channels_first' and tf_data_format in {'NWC', 'NHWC'}:
if data_format == 'channels_first' and tf_data_format == 'NWC':
x = tf.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
......@@ -3571,6 +3571,10 @@ def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
dilation_rate = (dilation_rate,)
x, tf_data_format = _preprocess_conv1d_input(x, data_format)
if tf_data_format == 'NWC':
tf_data_format = 'NHWC'
else:
tf_data_format = 'NCHW'
padding = _preprocess_padding(padding)
if tf_data_format == 'NHWC':
spatial_start_dim = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册