提交 da3a241a 编写于 作者: A Akshay Modi 提交者: TensorFlower Gardener

Nicer error message when watching things other than tensors/resource vars.

It would fail on ndarrays with an numpy.dtype doesn't have is_floating check on the next line in any case, so this gives it a nicer error message.

Also use the common RegisterType functionality to check for resource variables.

PiperOrigin-RevId: 258407797
上级 2bb4f008
......@@ -833,8 +833,15 @@ class GradientTape(object):
Args:
tensor: a Tensor or list of Tensors.
Raises:
ValueError: if it encounters something that is not a tensor.
"""
for t in nest.flatten(tensor):
if not (pywrap_tensorflow.IsTensor(t) or
pywrap_tensorflow.IsVariable(t)):
raise ValueError("Passed in object of type {}, not tf.Tensor".format(
type(t)))
if not t.dtype.is_floating:
logging.log_first_n(
logging.WARN, "The dtype of the watched tensor must be "
......
......@@ -1357,6 +1357,12 @@ class BackpropTest(test.TestCase):
tf_da = gradients.gradients(tf_max, [tf_aa])
self.assertAllEqual(da[0], tf_da[0].eval())
@test_util.run_in_graph_and_eager_modes
def testWatchBadThing(self):
g = backprop.GradientTape()
with self.assertRaisesRegexp(ValueError, 'ndarray'):
g.watch(np.array(1.))
class JacobianTest(test.TestCase):
......
......@@ -57,13 +57,6 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
// This function is not thread-safe.
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
// Registers e as the type of the ResourceVariable class.
// Returns Py_None if registration succeeds, else throws a TypeError and returns
// NULL.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
// Registers e as the VSpace to use.
// `vspace` must be a imperative_grad.py:VSpace named tuple.
PyObject* TFE_Py_RegisterVSpace(PyObject* e);
......
......@@ -716,8 +716,6 @@ PyObject* gradient_function = nullptr;
// Python function that returns output gradients given input gradients.
PyObject* forward_gradient_function = nullptr;
PyTypeObject* resource_variable_type = nullptr;
tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
......@@ -773,23 +771,6 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
Py_RETURN_NONE;
}
PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e) {
if (!PyType_Check(e)) {
PyErr_SetString(
PyExc_TypeError,
"TFE_Py_RegisterResourceVariableType: Need to register a type.");
return nullptr;
}
if (resource_variable_type != nullptr) {
Py_DECREF(resource_variable_type);
}
Py_INCREF(e);
resource_variable_type = reinterpret_cast<PyTypeObject*>(e);
Py_RETURN_NONE;
}
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
if (fallback_exception_class != nullptr) {
Py_DECREF(fallback_exception_class);
......@@ -2132,7 +2113,7 @@ PyObject* GetPythonObjectFromString(const char* s) {
}
bool CheckResourceVariable(PyObject* item) {
if (PyObject_TypeCheck(item, resource_variable_type)) {
if (tensorflow::swig::IsResourceVariable(item)) {
tensorflow::Safe_PyObjectPtr handle(
PyObject_GetAttrString(item, "_handle"));
return EagerTensor_CheckExact(handle.get());
......
......@@ -1779,7 +1779,7 @@ class UninitializedVariable(BaseResourceVariable):
synchronization=synchronization, aggregation=aggregation)
pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable)
pywrap_tensorflow.RegisterType("ResourceVariable", ResourceVariable)
math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access
......
......@@ -26,6 +26,7 @@ import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
......@@ -1322,6 +1323,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
Variable._OverloadAllOperators() # pylint: disable=protected-access
pywrap_tensorflow.RegisterType("Variable", Variable)
@tf_export(v1=["Variable"])
......
......@@ -62,7 +62,6 @@ limitations under the License.
%rename("%s") TFE_Py_RegisterForwardGradientFunction;
%rename("%s") TFE_Py_RegisterGradientFunction;
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
%rename("%s") TFE_Py_RegisterResourceVariableType;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_FastPathExecute;
%rename("%s") TFE_Py_RecordGradient;
......
......@@ -257,6 +257,26 @@ int IsTensorHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is a ResourceVariable.
// Returns 0 otherwise.
// Returns -1 if an error occurred.
int IsResourceVariableHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
return IsInstanceOfRegisteredType(to_check, "ResourceVariable");
});
return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is a ResourceVariable.
// Returns 0 otherwise.
// Returns -1 if an error occurred.
int IsVariableHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
return IsInstanceOfRegisteredType(to_check, "Variable");
});
return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
......@@ -812,6 +832,10 @@ bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
bool IsResourceVariable(PyObject* o) {
return IsResourceVariableHelper(o) == 1;
}
bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
// Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
......
......@@ -107,16 +107,34 @@ bool IsAttrs(PyObject* o);
// Returns a true if its input is an ops.Tensor.
//
// Args:
// seq: the input to be checked.
// o: the input to be checked.
//
// Returns:
// True if the object is a tensor.
bool IsTensor(PyObject* o);
// Returns a true if its input is a ResourceVariable.
//
// Args:
// o: the input to be checked.
//
// Returns:
// True if the object is a ResourceVariable.
bool IsResourceVariable(PyObject* o);
// Returns a true if its input is a Variable.
//
// Args:
// o: the input to be checked.
//
// Returns:
// True if the object is a Variable.
bool IsVariable(PyObject* o);
// Returns a true if its input is an ops.IndexesSlices.
//
// Args:
// seq: the input to be checked.
// o: the input to be checked.
//
// Returns:
// True if the object is an ops.IndexedSlices.
......
......@@ -34,6 +34,12 @@ limitations under the License.
%unignore tensorflow::swig::IsTensor;
%noexception tensorflow::swig::IsTensor;
%unignore tensorflow::swig::IsResourceVariable;
%noexception tensorflow::swig::IsResourceVariable;
%unignore tensorflow::swig::IsVariable;
%noexception tensorflow::swig::IsVariable;
%feature("docstring") tensorflow::swig::IsSequence
"""Returns true if its input is a collections.Sequence (except strings).
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册