提交 cff53408 编写于 作者: N Nathan Silberman 提交者: TensorFlower Gardener

Fixing bug in which the 'trainable' argument wasn't being passed to the bias...

Fixing bug in which the 'trainable' argument wasn't being passed to the bias in convolution2d_transpose.
Change: 139925626
上级 824c9f11
...@@ -220,14 +220,16 @@ def _fused_batch_norm( ...@@ -220,14 +220,16 @@ def _fused_batch_norm(
if original_rank is None: if original_rank is None:
raise ValueError('Inputs %s has undefined rank' % inputs.name) raise ValueError('Inputs %s has undefined rank' % inputs.name)
elif original_rank not in [2, 4]: elif original_rank not in [2, 4]:
raise ValueError('Inputs %s has unsupported rank. \ raise ValueError('Inputs %s has unsupported rank.'
Expected 2 or 4 but got %d' % (inputs.name, original_rank)) ' Expected 2 or 4 but got %d' % (
inputs.name, original_rank))
if original_rank == 2: if original_rank == 2:
channels = inputs.get_shape()[-1].value channels = inputs.get_shape()[-1].value
if channels is None: if channels is None:
raise ValueError('`C` dimension must be known but is None') raise ValueError('`C` dimension must be known but is None')
new_shape = [-1, channels, 1, 1] if data_format == DATA_FORMAT_NCHW else \ new_shape = [-1, 1, 1, channels]
[-1, 1, 1, channels] if data_format == DATA_FORMAT_NCHW:
new_shape = [-1, channels, 1, 1]
inputs = array_ops.reshape(inputs, new_shape) inputs = array_ops.reshape(inputs, new_shape)
inputs_shape = inputs.get_shape() inputs_shape = inputs.get_shape()
dtype = inputs.dtype.base_dtype dtype = inputs.dtype.base_dtype
...@@ -316,7 +318,7 @@ def _fused_batch_norm( ...@@ -316,7 +318,7 @@ def _fused_batch_norm(
need_updates = is_training_value is None or is_training_value need_updates = is_training_value is None or is_training_value
if need_updates: if need_updates:
if updates_collections is None: if updates_collections is None:
_no_updates = lambda: outputs no_updates = lambda: outputs
def _force_updates(): def _force_updates():
"""Internal function forces updates moving_vars if is_training.""" """Internal function forces updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average( update_moving_mean = moving_averages.assign_moving_average(
...@@ -326,7 +328,7 @@ def _fused_batch_norm( ...@@ -326,7 +328,7 @@ def _fused_batch_norm(
with ops.control_dependencies( with ops.control_dependencies(
[update_moving_mean, update_moving_variance]): [update_moving_mean, update_moving_variance]):
return array_ops.identity(outputs) return array_ops.identity(outputs)
outputs = utils.smart_cond(is_training, _force_updates, _no_updates) outputs = utils.smart_cond(is_training, _force_updates, no_updates)
else: else:
moving_vars_fn = lambda: (moving_mean, moving_variance) moving_vars_fn = lambda: (moving_mean, moving_variance)
def _delay_updates(): def _delay_updates():
...@@ -684,7 +686,7 @@ def bias_add(inputs, ...@@ -684,7 +686,7 @@ def bias_add(inputs,
raise ValueError('Dims of shape must be known but is None') raise ValueError('Dims of shape must be known but is None')
elif inputs_rank != 4 and data_format == DATA_FORMAT_NCHW: elif inputs_rank != 4 and data_format == DATA_FORMAT_NCHW:
raise ValueError('Data format NCHW only supports 4D Tensor') raise ValueError('Data format NCHW only supports 4D Tensor')
axis = 1 if data_format==DATA_FORMAT_NCHW else -1 axis = 1 if data_format == DATA_FORMAT_NCHW else -1
num_features = inputs_shape[axis].value num_features = inputs_shape[axis].value
if num_features is None: if num_features is None:
raise ValueError('`C` dimension must be known but is None') raise ValueError('`C` dimension must be known but is None')
...@@ -1081,7 +1083,6 @@ def convolution2d_transpose( ...@@ -1081,7 +1083,6 @@ def convolution2d_transpose(
output_shape = [batch_size, num_outputs, out_height, out_width] output_shape = [batch_size, num_outputs, out_height, out_width]
strides = [1, 1, stride_h, stride_w] strides = [1, 1, stride_h, stride_w]
output_shape = array_ops.pack(output_shape) output_shape = array_ops.pack(output_shape)
outputs = nn.conv2d_transpose(inputs, weights, output_shape, outputs = nn.conv2d_transpose(inputs, weights, output_shape,
strides, strides,
...@@ -1091,8 +1092,10 @@ def convolution2d_transpose( ...@@ -1091,8 +1092,10 @@ def convolution2d_transpose(
# Infer the static output shape: # Infer the static output shape:
out_shape = inputs.get_shape().as_list() out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = num_outputs out_shape[c_axis] = num_outputs
out_shape[h_axis] = get_deconv_dim(out_shape[h_axis], stride_h, kernel_h, padding) out_shape[h_axis] = get_deconv_dim(
out_shape[w_axis] = get_deconv_dim(out_shape[w_axis], stride_w, kernel_w, padding) out_shape[h_axis], stride_h, kernel_h, padding)
out_shape[w_axis] = get_deconv_dim(
out_shape[w_axis], stride_w, kernel_w, padding)
outputs.set_shape(out_shape) outputs.set_shape(out_shape)
if normalizer_fn is not None: if normalizer_fn is not None:
...@@ -1107,6 +1110,7 @@ def convolution2d_transpose( ...@@ -1107,6 +1110,7 @@ def convolution2d_transpose(
dtype=dtype, dtype=dtype,
initializer=biases_initializer, initializer=biases_initializer,
regularizer=biases_regularizer, regularizer=biases_regularizer,
trainable=trainable,
collections=biases_collections) collections=biases_collections)
outputs = nn.bias_add(outputs, biases, data_format=data_format) outputs = nn.bias_add(outputs, biases, data_format=data_format)
......
...@@ -23,9 +23,8 @@ import numpy as np ...@@ -23,9 +23,8 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
# TODO(sguada) Expose tf.with_dependencies # TODO(sguada) Expose tf.with_dependencies
from tensorflow.python.ops import control_flow_ops
from tensorflow.contrib.layers.python.layers import layers as _layers from tensorflow.contrib.layers.python.layers import layers as _layers
from tensorflow.python.ops import state_ops from tensorflow.python.ops import control_flow_ops
class AvgPool2DTest(tf.test.TestCase): class AvgPool2DTest(tf.test.TestCase):
...@@ -181,7 +180,8 @@ class PoolTest(tf.test.TestCase): ...@@ -181,7 +180,8 @@ class PoolTest(tf.test.TestCase):
height, width = 5, 8 height, width = 5, 8
images = tf.random_uniform((5, 3, height, width), seed=1) images = tf.random_uniform((5, 3, height, width), seed=1)
output = tf.contrib.layers.pool( output = tf.contrib.layers.pool(
images, [2, 3], dilation_rate=[1, 2], pooling_type='AVG', data_format='NCHW') images, [2, 3], dilation_rate=[1, 2], pooling_type='AVG',
data_format='NCHW')
self.assertEqual(output.get_shape().as_list(), [5, 3, 4, 4]) self.assertEqual(output.get_shape().as_list(), [5, 3, 4, 4])
...@@ -370,7 +370,7 @@ class ConvolutionTest(tf.test.TestCase): ...@@ -370,7 +370,7 @@ class ConvolutionTest(tf.test.TestCase):
tf.contrib.framework.get_variables_by_name('weights')[0]) tf.contrib.framework.get_variables_by_name('weights')[0])
wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0] wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEqual(wd.op.name, self.assertEqual(wd.op.name,
'Conv/weights/Regularizer/l2_regularizer') 'Conv/weights/Regularizer/l2_regularizer')
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
self.assertAlmostEqual(sess.run(wd), weight_decay * l2_loss.eval()) self.assertAlmostEqual(sess.run(wd), weight_decay * l2_loss.eval())
...@@ -588,6 +588,20 @@ class ConvolutionTest(tf.test.TestCase): ...@@ -588,6 +588,20 @@ class ConvolutionTest(tf.test.TestCase):
class Convolution2dTransposeTests(tf.test.TestCase): class Convolution2dTransposeTests(tf.test.TestCase):
def testTrainableFlagIsPassedOn(self):
for trainable in [True, False]:
with tf.Graph().as_default():
num_filters = 32
input_size = [5, 10, 12, 3]
images = tf.random_uniform(input_size, seed=1)
tf.contrib.layers.conv2d_transpose(
images, num_filters, [3, 3], stride=1, trainable=trainable)
model_variables = tf.contrib.framework.get_model_variables()
trainable_variables = tf.trainable_variables()
for model_variable in model_variables:
self.assertEqual(trainable, model_variable in trainable_variables)
def testInvalidDataFormat(self): def testInvalidDataFormat(self):
height, width = 7, 9 height, width = 7, 9
with self.test_session(): with self.test_session():
...@@ -597,7 +611,6 @@ class Convolution2dTransposeTests(tf.test.TestCase): ...@@ -597,7 +611,6 @@ class Convolution2dTransposeTests(tf.test.TestCase):
tf.contrib.layers.convolution2d_transpose( tf.contrib.layers.convolution2d_transpose(
images, 32, 3, data_format='CHWN') images, 32, 3, data_format='CHWN')
def testOutputSizeWithStrideOneSamePaddingNCHW(self): def testOutputSizeWithStrideOneSamePaddingNCHW(self):
# `NCHW` data fomat is only supported for `GPU` device. # `NCHW` data fomat is only supported for `GPU` device.
if tf.test.is_gpu_available(): if tf.test.is_gpu_available():
...@@ -615,7 +628,6 @@ class Convolution2dTransposeTests(tf.test.TestCase): ...@@ -615,7 +628,6 @@ class Convolution2dTransposeTests(tf.test.TestCase):
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size) self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithStrideOneValidPaddingNCHW(self): def testOutputSizeWithStrideOneValidPaddingNCHW(self):
if tf.test.is_gpu_available(): if tf.test.is_gpu_available():
with self.test_session(use_gpu=True) as sess: with self.test_session(use_gpu=True) as sess:
...@@ -756,7 +768,6 @@ class Convolution2dTransposeTests(tf.test.TestCase): ...@@ -756,7 +768,6 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size) self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithStrideOneSamePadding(self): def testOutputSizeWithStrideOneSamePadding(self):
num_filters = 32 num_filters = 32
input_size = [5, 10, 12, 3] input_size = [5, 10, 12, 3]
...@@ -1284,7 +1295,7 @@ class FlattenTest(tf.test.TestCase): ...@@ -1284,7 +1295,7 @@ class FlattenTest(tf.test.TestCase):
images = tf.random_uniform((5, height, width, 3), seed=1, name='images') images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
output = tf.contrib.layers.flatten(images) output = tf.contrib.layers.flatten(images)
self.assertEqual(output.get_shape().num_elements(), self.assertEqual(output.get_shape().num_elements(),
images.get_shape().num_elements()) images.get_shape().num_elements())
self.assertEqual(output.get_shape()[0], images.get_shape()[0]) self.assertEqual(output.get_shape()[0], images.get_shape()[0])
def testFlatten3D(self): def testFlatten3D(self):
...@@ -1293,7 +1304,7 @@ class FlattenTest(tf.test.TestCase): ...@@ -1293,7 +1304,7 @@ class FlattenTest(tf.test.TestCase):
images = tf.random_uniform((5, height, width), seed=1, name='images') images = tf.random_uniform((5, height, width), seed=1, name='images')
output = tf.contrib.layers.flatten(images) output = tf.contrib.layers.flatten(images)
self.assertEqual(output.get_shape().num_elements(), self.assertEqual(output.get_shape().num_elements(),
images.get_shape().num_elements()) images.get_shape().num_elements())
self.assertEqual(output.get_shape()[0], images.get_shape()[0]) self.assertEqual(output.get_shape()[0], images.get_shape()[0])
def testFlattenBatchSize(self): def testFlattenBatchSize(self):
...@@ -1303,10 +1314,10 @@ class FlattenTest(tf.test.TestCase): ...@@ -1303,10 +1314,10 @@ class FlattenTest(tf.test.TestCase):
inputs = tf.placeholder(tf.int32, (None, height, width, 3)) inputs = tf.placeholder(tf.int32, (None, height, width, 3))
output = tf.contrib.layers.flatten(inputs) output = tf.contrib.layers.flatten(inputs)
self.assertEqual(output.get_shape().as_list(), self.assertEqual(output.get_shape().as_list(),
[None, height * width * 3]) [None, height * width * 3])
output = sess.run(output, {inputs: images.eval()}) output = sess.run(output, {inputs: images.eval()})
self.assertEqual(output.size, self.assertEqual(output.size,
images.get_shape().num_elements()) images.get_shape().num_elements())
self.assertEqual(output.shape[0], images.get_shape()[0]) self.assertEqual(output.shape[0], images.get_shape()[0])
...@@ -1463,7 +1474,7 @@ class FCTest(tf.test.TestCase): ...@@ -1463,7 +1474,7 @@ class FCTest(tf.test.TestCase):
weights_regularizer=weight_decay) weights_regularizer=weight_decay)
wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0] wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEqual(wd.op.name, self.assertEqual(wd.op.name,
'fully_connected/weights/Regularizer/l2_regularizer') 'fully_connected/weights/Regularizer/l2_regularizer')
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
self.assertLess(sess.run(wd), 0.4) self.assertLess(sess.run(wd), 0.4)
...@@ -1620,9 +1631,9 @@ class BatchNormTest(tf.test.TestCase): ...@@ -1620,9 +1631,9 @@ class BatchNormTest(tf.test.TestCase):
update_moving_mean = update_layers[0] update_moving_mean = update_layers[0]
update_moving_variance = update_layers[1] update_moving_variance = update_layers[1]
self.assertEqual(update_moving_mean.op.name, self.assertEqual(update_moving_mean.op.name,
'BatchNorm/AssignMovingAvg') 'BatchNorm/AssignMovingAvg')
self.assertEqual(update_moving_variance.op.name, self.assertEqual(update_moving_variance.op.name,
'BatchNorm/AssignMovingAvg_1') 'BatchNorm/AssignMovingAvg_1')
def testReuseVariables(self): def testReuseVariables(self):
height, width = 3, 3 height, width = 3, 3
...@@ -1774,8 +1785,8 @@ class BatchNormTest(tf.test.TestCase): ...@@ -1774,8 +1785,8 @@ class BatchNormTest(tf.test.TestCase):
if fused: if fused:
# Add Bessel's correction # Add Bessel's correction
moving_variance_corrected = moving_variance / correction_factor moving_variance_corrected = moving_variance / correction_factor
correct_moving_variance = state_ops.assign(moving_variance, correct_moving_variance = tf.assign(moving_variance,
moving_variance_corrected) moving_variance_corrected)
sess.run(correct_moving_variance) sess.run(correct_moving_variance)
self.assertAllClose(variance, expected_var) self.assertAllClose(variance, expected_var)
...@@ -1888,8 +1899,8 @@ class BatchNormTest(tf.test.TestCase): ...@@ -1888,8 +1899,8 @@ class BatchNormTest(tf.test.TestCase):
if fused: if fused:
# Add Bessel's correction # Add Bessel's correction
moving_variance_corrected = moving_variance / correction_factor moving_variance_corrected = moving_variance / correction_factor
correct_moving_variance = state_ops.assign(moving_variance, correct_moving_variance = tf.assign(moving_variance,
moving_variance_corrected) moving_variance_corrected)
sess.run(correct_moving_variance) sess.run(correct_moving_variance)
self.assertAllClose(variance, expected_var) self.assertAllClose(variance, expected_var)
# After convergence output_train and output_eval should be the same. # After convergence output_train and output_eval should be the same.
...@@ -1961,8 +1972,8 @@ class BatchNormTest(tf.test.TestCase): ...@@ -1961,8 +1972,8 @@ class BatchNormTest(tf.test.TestCase):
if fused: if fused:
# Add Bessel's correction # Add Bessel's correction
moving_variance_corrected = moving_variance / correction_factor moving_variance_corrected = moving_variance / correction_factor
correct_moving_variance = state_ops.assign(moving_variance, correct_moving_variance = tf.assign(moving_variance,
moving_variance_corrected) moving_variance_corrected)
sess.run(correct_moving_variance) sess.run(correct_moving_variance)
output_false = sess.run([output], {is_training: False}) output_false = sess.run([output], {is_training: False})
self.assertAllClose(output_true, output_false) self.assertAllClose(output_true, output_false)
...@@ -2100,8 +2111,8 @@ class BatchNormTest(tf.test.TestCase): ...@@ -2100,8 +2111,8 @@ class BatchNormTest(tf.test.TestCase):
if fused: if fused:
# Add Bessel's correction # Add Bessel's correction
moving_variance_corrected = moving_variance / correction_factor moving_variance_corrected = moving_variance / correction_factor
correct_moving_variance = state_ops.assign(moving_variance, correct_moving_variance = tf.assign(moving_variance,
moving_variance_corrected) moving_variance_corrected)
sess.run(correct_moving_variance) sess.run(correct_moving_variance)
output_false = sess.run([output], {is_training: False}) output_false = sess.run([output], {is_training: False})
self.assertTrue(np.allclose(output_true, output_false)) self.assertTrue(np.allclose(output_true, output_false))
...@@ -2212,10 +2223,10 @@ class BatchNormTest(tf.test.TestCase): ...@@ -2212,10 +2223,10 @@ class BatchNormTest(tf.test.TestCase):
scale=True, scale=True,
epsilon=0.0, epsilon=0.0,
param_initializers={ param_initializers={
'beta': beta, 'beta': beta,
'gamma': gamma, 'gamma': gamma,
'moving_mean': mean, 'moving_mean': mean,
'moving_variance': variance, 'moving_variance': variance,
}) })
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
outs = sess.run(output) outs = sess.run(output)
...@@ -2358,6 +2369,7 @@ class LayerNormTest(tf.test.TestCase): ...@@ -2358,6 +2369,7 @@ class LayerNormTest(tf.test.TestCase):
def testOutput4DInput(self): def testOutput4DInput(self):
self.doOutputTest((100, 10, 10, 3)) self.doOutputTest((100, 10, 10, 3))
class MaxPool2DTest(tf.test.TestCase): class MaxPool2DTest(tf.test.TestCase):
def testInvalidDataFormat(self): def testInvalidDataFormat(self):
...@@ -2974,7 +2986,7 @@ class LegacyFullyConnectedTest(tf.test.TestCase): ...@@ -2974,7 +2986,7 @@ class LegacyFullyConnectedTest(tf.test.TestCase):
self.assertEqual(1, len(tf.get_collection('unbiased'))) self.assertEqual(1, len(tf.get_collection('unbiased')))
self.assertEqual(1, len(tf.get_collection('biased'))) self.assertEqual(1, len(tf.get_collection('biased')))
self.assertEqual(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), self.assertEqual(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
tf.get_collection('all')) tf.get_collection('all'))
def test_no_bias(self): def test_no_bias(self):
tf.contrib.layers.legacy_relu(self.input, 2, bias_init=None) tf.contrib.layers.legacy_relu(self.input, 2, bias_init=None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册