From 54d915b16dc8e906efb543a205c515c6d3fb9f3e Mon Sep 17 00:00:00 2001 From: Illia Polosukhin Date: Tue, 21 Jun 2016 00:20:47 -0800 Subject: [PATCH] Raising errors for humans: * optimize_loss expects 0d Tensor. * Estimator.get_eval_ops requires to have `targets` not None. * Adding asserts for shapes in rnn_cells. * Added error check in rnn.rnn. Change: 125425807 --- .../layers/python/layers/optimizers.py | 5 +- .../layers/python/layers/optimizers_test.py | 15 ++++++ .../python/learn/estimators/estimator.py | 3 ++ .../python/learn/estimators/estimator_test.py | 7 +++ .../python/kernel_tests/rnn_cell_test.py | 12 +++++ tensorflow/python/kernel_tests/rnn_test.py | 14 ++++++ tensorflow/python/ops/rnn.py | 16 +++++-- tensorflow/python/ops/rnn_cell.py | 46 ++++++++++++++++++- 8 files changed, 112 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index cfee1ec142f..94984af15e0 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -92,8 +92,11 @@ def optimize_loss(loss, Training op. Raises: - ValueError: if optimizer is wrong type. + ValueError: if `optimizer` is wrong type or `loss` is not 0d Tensor. """ + if not isinstance(loss, ops.Tensor) or loss.get_shape().ndims > 0: + raise ValueError("optimize_loss expects loss to be 0d Tensor, " + "got %s" % loss) with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"): # Update ops take UPDATE_OPS collection if not provided. update_ops = (set(update_ops or []) or diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 45440333dd1..4f21787ba4a 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -63,6 +63,21 @@ class OptimizersTest(tf.test.TestCase): learning_rate=0.1, optimizer=optimizer) + def testWrongLoss(self): + with tf.Graph().as_default() as g: + with self.test_session(graph=g): + _, _, _, global_step = _setup_model() + with self.assertRaises(ValueError): + tf.contrib.layers.optimize_loss(None, + global_step, + learning_rate=0.1, + optimizer="SGD") + with self.assertRaises(ValueError): + tf.contrib.layers.optimize_loss([[1.0]], + global_step, + learning_rate=0.1, + optimizer="SGD") + def testGradientNoise(self): tf.set_random_seed(42) with self.test_session() as session: diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 135afb04e5d..d4bc100974e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -729,6 +729,9 @@ class Estimator(BaseEstimator): Raises: ValueError: if `metrics` don't match `targets`. """ + if targets is None: + raise ValueError('Metrics %s in Estimator.evaluate requires targets ' + 'not be None.' % metrics) predictions, loss, _ = self._call_model_fn(features, targets, ModeKeys.EVAL) result = {'loss': loss} metrics = metrics or {} diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 856e61d2fff..aa514d415d0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -213,6 +213,13 @@ class EstimatorTest(tf.test.TestCase): est.fit(input_fn=boston_input_fn, steps=1) _ = est.evaluate(input_fn=boston_eval_fn, steps=1) + def testEvalNoTarget(self): + boston = tf.contrib.learn.datasets.load_boston() + est = tf.contrib.learn.Estimator(model_fn=linear_model_fn) + est.fit(input_fn=boston_input_fn, steps=1) + with self.assertRaises(ValueError): + _ = est.evaluate(x=boston.data, steps=1) + def testPredict(self): est = tf.contrib.learn.Estimator(model_fn=linear_model_fn) boston = tf.contrib.learn.datasets.load_boston() diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index 886c3b757aa..337dd0000b8 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -83,6 +83,18 @@ class RNNCellTest(tf.test.TestCase): # Smoke test self.assertAllClose(res[0], [[0.156736, 0.156736]]) + def testGRUCellMismatch(self): + with self.test_session(): + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + x_incorrect = tf.zeros([1, 2, 3]) + m = tf.zeros([1, 2]) + m_incorrect = tf.zeros([1, 2, 3]) + with self.assertRaises(ValueError): + tf.nn.rnn_cell.GRUCell(2)(x_incorrect, m) + with self.assertRaises(ValueError): + tf.nn.rnn_cell.GRUCell(2)(x, m_incorrect) + def testBasicLSTMCell(self): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 34b1c81e77b..117fb0ca570 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -1025,6 +1025,20 @@ class LSTMTest(tf.test.TestCase): self._testDynamicEquivalentToStaticRNN( use_gpu=True, use_sequence_length=True) + def testRNNWrongInputs(self): + cell = tf.nn.rnn_cell.GRUCell(2) + inputs = tf.constant([[1., 2.], [1., 2.]]) + inputs_wrong = tf.constant([[[1., 2.], [1., 2.]], [[1., 2.], [1., 2.]]]) + state_wrong = tf.constant([[1., 2.], [1., 2.]]) + with self.assertRaises(TypeError): + tf.nn.rnn(None, inputs) + with self.assertRaises(TypeError): + tf.nn.rnn(cell, None) + with self.assertRaises(TypeError): + tf.nn.rnn(cell, inputs_wrong) + with self.assertRaises(TypeError): + tf.nn.rnn(cell, inputs, state_wrong) + class BidirectionalRNNTest(tf.test.TestCase): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index d8a96db0a2b..345bd56dc6b 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -91,11 +91,19 @@ def rnn(cell, inputs, initial_state=None, dtype=None, """ if not isinstance(cell, rnn_cell.RNNCell): - raise TypeError("cell must be an instance of RNNCell") + raise TypeError("Rnn expects cell to be an instance of RNNCell, " + "got %s" % cell) if not isinstance(inputs, list): - raise TypeError("inputs must be a list") + raise TypeError("Rnn expects inputs to be a list, got %s" % inputs) if not inputs: - raise ValueError("inputs must not be empty") + raise ValueError("Rnn expects inputs not to be empty, got %s" % inputs) + first_el_shape = inputs[0].get_shape() + for i, inp in enumerate(inputs): + if not inp.get_shape().is_compatible_with(first_el_shape): + raise ValueError("Rnn expectes inputs to be list of compatible Tensors. " + "Received %s, where inputs[%d].shape == %s doesn't " + "match inputs[0].shape == %s" % ( + inputs, i, inp.get_shape(), first_el_shape)) outputs = [] # Create a new scope in which the caching device is either @@ -127,7 +135,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None, else: if not dtype: raise ValueError("If no initial_state is provided, " - "dtype must be specified") + "dtype must be specified") state = cell.zero_state(batch_size, dtype) if sequence_length is not None: # Prepare variables diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py index ee9c06dd132..dd87f5e6fe7 100644 --- a/tensorflow/python/ops/rnn_cell.py +++ b/tensorflow/python/ops/rnn_cell.py @@ -171,6 +171,30 @@ class RNNCell(object): return zeros +def _assert_rnn_cell_input_shapes(inputs, state, state_is_tuple): + inputs = ops.convert_to_tensor(inputs) + if inputs.get_shape().ndims is not None and inputs.get_shape().ndims != 2: + raise ValueError("RNN cells expects inputs to be 2D Tensor, " + "got inputs with %s shape." % inputs.get_shape()) + # If state is None, don't check. + if state is None: + return + if state_is_tuple: + c, h = state + c = ops.convert_to_tensor(c) + h = ops.convert_to_tensor(h) + if ((c.get_shape().ndims is not None and c.get_shape().ndims != 2) or + (h.get_shape().ndims is not None and h.get_shape().ndims != 2)): + raise ValueError("RNN cells expects state to be tuple of 2D Tensors, " + "got state tuple with shape: (%s, %s)." % ( + c.get_shape(), h.get_shape())) + else: + state = ops.convert_to_tensor(state) + if state.get_shape().ndims is not None and state.get_shape().ndims != 2: + raise ValueError("RNN cells expect state to be 2D Tensor, " + "got state with sahpe: %s." % state.get_shape()) + + class BasicRNNCell(RNNCell): """The most basic RNN cell.""" @@ -189,7 +213,19 @@ class BasicRNNCell(RNNCell): return self._num_units def __call__(self, inputs, state, scope=None): - """Most basic RNN: output = new_state = activation(W * input + U * state + B).""" + """Most basic RNN. + + output = new_state = activation(W * input + U * state + B). + + Args: + inputs: 2D `Tensor`. + state: `tuple` of `Tensor`s or 2D `Tensor`. + scope: name of `VariableScope` object. + + Returns: + `tuple` of output and state `Tensor`, where they are equal for this cell. + """ + _assert_rnn_cell_input_shapes(inputs, state, False) with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell" output = self._activation(_linear([inputs, state], self._num_units, True)) return output, output @@ -214,6 +250,7 @@ class GRUCell(RNNCell): def __call__(self, inputs, state, scope=None): """Gated recurrent unit (GRU) with nunits cells.""" + _assert_rnn_cell_input_shapes(inputs, state, False) with vs.variable_scope(scope or type(self).__name__): # "GRUCell" with vs.variable_scope("Gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. @@ -289,6 +326,7 @@ class BasicLSTMCell(RNNCell): def __call__(self, inputs, state, scope=None): """Long short-term memory cell (LSTM).""" + _assert_rnn_cell_input_shapes(inputs, state, self._state_is_tuple) with vs.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: @@ -460,6 +498,8 @@ class LSTMCell(RNNCell): ValueError: If input size cannot be inferred from inputs via static shape inference. """ + _assert_rnn_cell_input_shapes(inputs, state, self._state_is_tuple) + num_proj = self._num_units if self._num_proj is None else self._num_proj if self._state_is_tuple: @@ -564,6 +604,7 @@ class OutputProjectionWrapper(RNNCell): """Run the cell and output projection on inputs, starting from state.""" output, res_state = self._cell(inputs, state) # Default scope: "OutputProjectionWrapper" + _assert_rnn_cell_input_shapes(inputs, None, None) with vs.variable_scope(scope or type(self).__name__): projected = _linear(output, self._output_size, True) return projected, res_state @@ -606,6 +647,7 @@ class InputProjectionWrapper(RNNCell): def __call__(self, inputs, state, scope=None): """Run the input projection and then the cell.""" # Default scope: "InputProjectionWrapper" + _assert_rnn_cell_input_shapes(inputs, None, None) with vs.variable_scope(scope or type(self).__name__): projected = _linear(inputs, self._num_proj, True) return self._cell(projected, state) @@ -657,6 +699,7 @@ class DropoutWrapper(RNNCell): def __call__(self, inputs, state, scope=None): """Run the cell with the declared dropouts.""" + _assert_rnn_cell_input_shapes(inputs, None, None) if (not isinstance(self._input_keep_prob, float) or self._input_keep_prob < 1): inputs = nn_ops.dropout(inputs, self._input_keep_prob, seed=self._seed) @@ -767,6 +810,7 @@ class MultiRNNCell(RNNCell): def __call__(self, inputs, state, scope=None): """Run this multi-layer cell on inputs, starting from state.""" + _assert_rnn_cell_input_shapes(inputs, state, self._state_is_tuple) with vs.variable_scope(scope or type(self).__name__): # "MultiRNNCell" cur_state_pos = 0 cur_inp = inputs -- GitLab