diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index e2b634ee8f8d18e1e0e43a9e10cb7f2532bbbf12..e656998b7070cca963de443203188053f0c43575 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1465,7 +1465,7 @@ def count_nonzero_v2(input, # pylint: disable=redefined-builtin return cast( reduce_sum( # int64 reduction happens on GPU - to_int64(gen_math_ops.not_equal(input, zero)), + cast(gen_math_ops.not_equal(input, zero), dtypes.int64), axis=axis, keepdims=keepdims), dtype=dtype)