# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for aggregate_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import tensor_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


class AddNTest(test.TestCase):
  """Tests for the AddN op (`math_ops.add_n`)."""

  # AddN special-cases adding the first M inputs to make (N - M) divisible by 8,
  # after which it adds the remaining (N - M) tensors 8 at a time in a loop.
  # Test N in [1, 10] so we check each special-case from 1 to 9 and one
  # iteration of the loop.
  _MAX_N = 10

  def _supported_types(self):
    """Returns the dtypes AddN supports on the available device.

    The GPU kernel set is narrower (no small-int types), so we restrict the
    list when a GPU is present to avoid placing unsupported ops on it.
    """
    if test.is_gpu_available():
      return [
          dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
          dtypes.complex128, dtypes.int64
      ]
    return [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
            dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
            dtypes.complex128]

  def _buildData(self, shape, dtype):
    """Returns a random numpy array of `shape` with numpy dtype for `dtype`."""
    data = np.random.randn(*shape).astype(dtype.as_numpy_dtype)
    # For complex types, add an index-dependent imaginary component so we can
    # tell we got the right value.
    if dtype.is_complex:
      return data + 10j * data
    return data

  def testAddN(self):
    """add_n over 1..`_MAX_N` inputs matches a numpy reference sum."""
    np.random.seed(12345)
    # NOTE: `self.evaluate` is used below, so no session handle is needed.
    with self.session(use_gpu=True):
      for dtype in self._supported_types():
        for count in range(1, self._MAX_N + 1):
          data = [self._buildData((2, 2), dtype) for _ in range(count)]
          actual = self.evaluate(math_ops.add_n(data))
          expected = np.sum(np.vstack(
              [np.expand_dims(d, 0) for d in data]), axis=0)
          # float16 accumulates noticeably more rounding error.
          tol = 5e-3 if dtype == dtypes.float16 else 5e-7
          self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  @test_util.run_deprecated_v1
  def testUnknownShapes(self):
    """add_n works when input shapes are unknown at graph-build time."""
    np.random.seed(12345)
    with self.session(use_gpu=True) as sess:
      for dtype in self._supported_types():
        data = self._buildData((2, 2), dtype)
        for count in range(1, self._MAX_N + 1):
          # Placeholder with no shape forces the unknown-shape code path.
          data_ph = array_ops.placeholder(dtype=dtype)
          actual = sess.run(math_ops.add_n([data_ph] * count), {data_ph: data})
          expected = np.sum(np.vstack([np.expand_dims(data, 0)] * count),
                            axis=0)
          tol = 5e-3 if dtype == dtypes.float16 else 5e-7
          self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  @test_util.run_deprecated_v1
  def testVariant(self):
    """add_n on DT_VARIANT tensors holding ints sums their payloads."""

    def create_constant_variant(value):
      # Builds a scalar variant tensor wrapping `value` as an int payload.
      return constant_op.constant(
          tensor_pb2.TensorProto(
              dtype=dtypes.variant.as_datatype_enum,
              tensor_shape=tensor_shape.TensorShape([]).as_proto(),
              variant_val=[
                  tensor_pb2.VariantTensorDataProto(
                      # Match registration in variant_op_registry.cc
                      type_name=b"int",
                      metadata=np.array(value, dtype=np.int32).tobytes())
              ]))

    # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant
    # copying between CPU and GPU is supported.
    with self.session(use_gpu=False):
      num_tests = 127
      values = list(range(100))
      variant_consts = [create_constant_variant(x) for x in values]
      sum_count_indices = np.random.randint(1, 29, size=num_tests)
      sum_indices = [
          np.random.randint(100, size=count) for count in sum_count_indices]
      expected_sums = [np.sum(x) for x in sum_indices]
      variant_sums = [math_ops.add_n([variant_consts[i] for i in x])
                      for x in sum_indices]

      # We use as_string() to get the Variant DebugString for the
      # variant_sums; we know its value so we can check via string equality
      # here.
      #
      # Right now, non-numpy-compatible objects cannot be returned from a
      # session.run call; similarly, objects that can't be converted to
      # native numpy types cannot be passed to ops.convert_to_tensor.
      variant_sums_string = string_ops.as_string(variant_sums)
      self.assertAllEqual(
          variant_sums_string,
          ["Variant<type: int value: {}>".format(s).encode("utf-8")
           for s in expected_sums])

# Run all test cases in this module under the TensorFlow test runner.
if __name__ == "__main__":
  test.main()