提交 742ec63b 编写于 作者: I Ian Langmore 提交者: TensorFlower Gardener

check_ops BUGFIX: Call convert_to_tensor on args before doing anything else!

Change: 138138065
上级 ee520b4c
......@@ -498,8 +498,8 @@ def assert_greater_equal(x, y, data=None, summarize=None, message=None,
return control_flow_ops.Assert(condition, data, summarize=summarize)
def _assert_rank_condition(x, rank, static_condition, dynamic_condition, data,
summarize, name):
def _assert_rank_condition(
x, rank, static_condition, dynamic_condition, data, summarize):
"""Assert `x` has a rank that satisfies a given condition.
Args:
......@@ -512,8 +512,6 @@ def _assert_rank_condition(x, rank, static_condition, dynamic_condition, data,
data: The tensors to print out if the condition is false. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
name: A name for this operation (optional).
Defaults to "assert_rank_at_least".
Returns:
Op raising `InvalidArgumentError` if `x` fails dynamic_condition.
......@@ -521,34 +519,30 @@ def _assert_rank_condition(x, rank, static_condition, dynamic_condition, data,
Raises:
ValueError: If static checks determine `x` fails static_condition.
"""
with ops.name_scope(name, 'assert_rank', [x]):
x = ops.convert_to_tensor(x, name='x')
rank = ops.convert_to_tensor(rank, name='rank')
# Attempt to statically defined rank.
x_rank_static = x.get_shape().ndims
rank_static = tensor_util.constant_value(rank)
# Attempt to statically defined rank.
x_rank_static = x.get_shape().ndims
rank_static = tensor_util.constant_value(rank)
assert_type(rank, dtypes.int32)
assert_type(rank, dtypes.int32)
if rank_static is not None:
if rank_static.ndim != 0:
raise ValueError('Rank must be a scalar')
if rank_static is not None:
if rank_static.ndim != 0:
raise ValueError('Rank must be a scalar')
if x_rank_static is not None:
if not static_condition(x_rank_static, rank_static):
raise ValueError(
'Static rank condition failed', x_rank_static, rank_static)
return control_flow_ops.no_op(name='static_checks_determined_all_ok')
if x_rank_static is not None:
if not static_condition(x_rank_static, rank_static):
raise ValueError(
'Static rank condition failed', x_rank_static, rank_static)
return control_flow_ops.no_op(name='static_checks_determined_all_ok')
condition = dynamic_condition(array_ops.rank(x), rank)
condition = dynamic_condition(array_ops.rank(x), rank)
# Add the condition that `rank` must have rank zero. Prevents the bug where
# someone does assert_rank(x, [n]), rather than assert_rank(x, n).
if rank_static is None:
this_data = ['Rank must be a scalar. Received rank: ', rank]
rank_check = assert_rank(rank, 0, data=this_data)
condition = control_flow_ops.with_dependencies([rank_check], condition)
# Add the condition that `rank` must have rank zero. Prevents the bug where
# someone does assert_rank(x, [n]), rather than assert_rank(x, n).
if rank_static is None:
this_data = ['Rank must be a scalar. Received rank: ', rank]
rank_check = assert_rank(rank, 0, data=this_data)
condition = control_flow_ops.with_dependencies([rank_check], condition)
return control_flow_ops.Assert(condition, data, summarize=summarize)
......@@ -585,29 +579,32 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
Raises:
ValueError: If static checks determine `x` has wrong rank.
"""
message = message or ''
with ops.name_scope(name, 'assert_rank', [x]):
x = ops.convert_to_tensor(x, name='x')
rank = ops.convert_to_tensor(rank, name='rank')
message = message or ''
static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
dynamic_condition = math_ops.equal
static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
dynamic_condition = math_ops.equal
if data is None:
data = [
message,
'Tensor %s must have rank' % x.name, rank, 'Received shape: ',
array_ops.shape(x)
]
if data is None:
data = [
message,
'Tensor %s must have rank' % x.name, rank, 'Received shape: ',
array_ops.shape(x)
]
try:
assert_op = _assert_rank_condition(x, rank, static_condition,
dynamic_condition, data, summarize, name)
try:
assert_op = _assert_rank_condition(x, rank, static_condition,
dynamic_condition, data, summarize)
except ValueError as e:
if e.args[0] == 'Static rank condition failed':
raise ValueError(
'%s. Tensor %s must have rank %d. Received rank %d, shape %s' %
(message, x.name, e.args[2], e.args[1], x.get_shape()))
else:
raise
except ValueError as e:
if e.args[0] == 'Static rank condition failed':
raise ValueError(
'%s. Tensor %s must have rank %d. Received rank %d, shape %s' %
(message, x.name, e.args[2], e.args[1], x.get_shape()))
else:
raise
return assert_op
......@@ -646,28 +643,31 @@ def assert_rank_at_least(
Raises:
ValueError: If static checks determine `x` has wrong rank.
"""
message = message or ''
with ops.name_scope(name, 'assert_rank_at_least', [x]):
x = ops.convert_to_tensor(x, name='x')
rank = ops.convert_to_tensor(rank, name='rank')
message = message or ''
static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
dynamic_condition = math_ops.greater_equal
if data is None:
data = [
message,
'Tensor %s must have rank at least' % x.name, rank,
'Received shape: ', array_ops.shape(x)
]
static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
dynamic_condition = math_ops.greater_equal
if data is None:
data = [
message,
'Tensor %s must have rank at least' % x.name, rank,
'Received shape: ', array_ops.shape(x)
]
try:
assert_op = _assert_rank_condition(x, rank, static_condition,
dynamic_condition, data, summarize, name)
except ValueError as e:
if e.args[0] == 'Static rank condition failed':
raise ValueError(
'%s. Tensor %s must have rank at least %d. Received rank %d, shape '
'%s' % (message, x.name, e.args[2], e.args[1], x.get_shape()))
else:
raise
try:
assert_op = _assert_rank_condition(x, rank, static_condition,
dynamic_condition, data, summarize)
except ValueError as e:
if e.args[0] == 'Static rank condition failed':
raise ValueError(
'%s. Tensor %s must have rank at least %d. Received rank %d, '
'shape %s' % (message, x.name, e.args[2], e.args[1], x.get_shape()))
else:
raise
return assert_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册