# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow.python.framework.importer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.python.platform

import numpy as np
import tensorflow as tf

from google.protobuf import text_format

from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import device
from tensorflow.python.framework import op_def_registry


# Dummy OpDefs used by the tests below. Most names encode the signature: a
# leading 'O' (outputs) or 'I' (inputs) followed by one letter per argument,
# where 'i' is int32, 'f' is float and 'r' is an int32 ref. 'In' takes a list
# of N tensors of type T, 'Otl' produces a list of tensors whose types come
# from attr 't', 'Unary' maps a T to a T, and 'None' has no inputs/outputs.
# 'Otl' must declare its attr as lowercase 't' so that it matches the name
# referenced by its output's `type_list_attr`.
_op_list = op_def_pb2.OpList()
text_format.Merge("""
  op {
    name: 'None'
  }
  op {
    name: 'Oi'
    output_arg { name: 'a' type: DT_INT32 }
  }
  op {
    name: 'Or'
    output_arg { name: 'a' type: DT_INT32 is_ref: true }
  }
  op {
    name: 'Of'
    output_arg { name: 'a' type: DT_FLOAT }
  }
  op {
    name: 'Ii'
    input_arg { name: 'a' type: DT_INT32 }
  }
  op {
    name: 'If'
    input_arg { name: 'a' type: DT_FLOAT }
  }
  op {
    name: 'Oii'
    output_arg { name: 'a' type: DT_INT32 }
    output_arg { name: 'b' type: DT_INT32 }
  }
  op {
    name: 'Oif'
    output_arg { name: 'a' type: DT_INT32 }
    output_arg { name: 'b' type: DT_FLOAT }
  }
  op {
    name: 'Iii'
    input_arg { name: 'a' type: DT_INT32 }
    input_arg { name: 'b' type: DT_INT32 }
  }
  op {
    name: 'Iff'
    input_arg { name: 'a' type: DT_FLOAT }
    input_arg { name: 'b' type: DT_FLOAT }
  }
  op {
    name: 'Iif'
    input_arg { name: 'a' type: DT_INT32 }
    input_arg { name: 'b' type: DT_FLOAT }
  }
  op {
    name: 'Iri'
    input_arg { name: 'a' type: DT_INT32 is_ref: true }
    input_arg { name: 'b' type: DT_INT32 }
  }
  op {
    name: 'In'
    input_arg { name: 'a' number_attr: 'N' type_attr: 'T' }
    attr { name: 'N' type: 'int' minimum: 1 }
    attr { name: 'T' type: 'type' }
  }
  op {
    name: 'Otl'
    output_arg { name: 'a' type_list_attr: 't' }
    attr { name: 't' type: 'list(type)' minimum: 1 }
  }
  op {
    name: 'Unary'
    input_arg { name: 'a' type_attr: 'T' }
    output_arg { name: 'b' type_attr: 'T' }
    attr { name: 'T' type: 'type' }
  }
""", _op_list)
op_def_registry.register_op_list(_op_list)
# NOTE(mrry): Dummy shape registrations for ops used in the tests. Each op gets
# a `None` shape function via `tf.RegisterShape`.
for op_def in _op_list.op:
  tf.RegisterShape(op_def.name)(None)


class ImportGraphDefTest(tf.test.TestCase):
  """Tests for `tf.import_graph_def` built on the dummy ops in this module."""

  def _MakeGraphDef(self, text, version=tf.GRAPH_DEF_VERSION):
    """Parses a text-format GraphDef body, prepending a `version` field.

    Args:
      text: GraphDef contents in protobuf text format (e.g. `node { ... }`
        entries), without a `version` field.
      version: Integer written into the GraphDef's `version` field.

    Returns:
      A `tf.GraphDef` proto parsed from `text`.
    """
    text = "version: %d\n%s" % (version, text)
    ret = tf.GraphDef()
    text_format.Merge(text, ret)
    return ret

  def testBasic(self):
    """Imports multi-output ops and checks names, dtypes and wiring."""
    with tf.Graph().as_default():
      a, b, c, d = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Oif' }
          node { name: 'B' op: 'Otl'
                 attr { key: 't'
                        value { list { type: DT_INT32 type: DT_FLOAT } } } }
          node { name: 'C' op: 'In'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'In'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_FLOAT } }
                 input: 'A:1' input: 'B:1' }
          """),
          return_elements=['A', 'B', 'C', 'D'],
          name='import')

      # Assert that the import process creates distinct tensors.
      self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
      self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
      self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
      self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
      self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
      self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)

      # Assert that the ops are connected according to the GraphDef topology.
      self.assertEqual(c.inputs[0], a.outputs[0])
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], b.outputs[1])

      # Check the types of the returned ops and tensors.
      self.assertEqual(a.type, 'Oif')
      self.assertEqual(b.type, 'Otl')
      self.assertEqual(c.type, 'In')
      self.assertEqual(d.type, 'In')
      self.assertEqual(a.outputs[0].dtype, tf.int32)
      self.assertEqual(a.outputs[1].dtype, tf.float32)
      self.assertEqual(b.outputs[0].dtype, tf.int32)
      self.assertEqual(b.outputs[1].dtype, tf.float32)

      # Check the names of the returned ops.
      self.assertEqual(a.name, 'import/A')
      self.assertEqual(b.name, 'import/B')
      self.assertEqual(c.name, 'import/C')
      self.assertEqual(d.name, 'import/D')

  def _CheckInputMap(self, a_0_key, b_1_key, return_elements):
    """Imports a four-node graph with tensors 'A:0' and 'B:1' remapped.

    Shared by the input-map tests, which differ only in the string type
    (str, bytes or unicode) used for the `input_map` keys and return names.

    Args:
      a_0_key: `input_map` key naming tensor 'A:0'.
      b_1_key: `input_map` key naming tensor 'B:1'.
      return_elements: Names of ops 'A', 'B', 'C' and 'D', in that order.
    """
    with tf.Graph().as_default():
      feed_a_0 = tf.constant(0, dtype=tf.int32)
      feed_b_1 = tf.constant(1, dtype=tf.int32)

      a, b, c, d = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Oii' }
          node { name: 'B' op: 'Oii' }
          node { name: 'C' op: 'In'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'In'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:1' input: 'B:1' }
          """),
          input_map={a_0_key: feed_a_0, b_1_key: feed_b_1},
          return_elements=return_elements)

      # Mapped tensors are replaced by the feeds; unmapped ones keep their
      # original producers from the imported graph.
      self.assertEqual(c.inputs[0], feed_a_0)
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], feed_b_1)

  def testInputMap(self):
    """Remaps inputs using `str` keys."""
    self._CheckInputMap('A:0', 'B:1', ['A', 'B', 'C', 'D'])

  def testInputMapBytes(self):
    """Remaps inputs using `bytes` keys and return element names."""
    self._CheckInputMap(b'A:0', b'B:1', [b'A', b'B', b'C', b'D'])

  def testInputMapUnicode(self):
    """Remaps inputs using unicode keys and return element names."""
    self._CheckInputMap(u'A:0', u'B:1', [u'A', u'B', u'C', u'D'])

  def testImplicitZerothOutput(self):
    """An input written as 'A' refers to output 'A:0'."""
    with tf.Graph().as_default():
      a, b = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Oii' }
          node { name: 'B' op: 'Ii' input: 'A' }
          """),
          return_elements=['A', 'B'])

      self.assertEqual(b.inputs[0], a.outputs[0])

  def testInputMapImplicitZerothOutput(self):
    """An `input_map` key 'A' remaps the graph's 'A:0' input."""
    with tf.Graph().as_default():
      feed_a_0 = tf.constant(0, dtype=tf.int32)
      b, = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Oii' }
          node { name: 'B' op: 'Ii' input: 'A:0' }
          """),
          input_map={'A': feed_a_0},
          return_elements=['B'])

      self.assertEqual(b.inputs[0], feed_a_0)

  def testWithControlDependency(self):
    """A '^A' input becomes a control input rather than a data input."""
    with tf.Graph().as_default():
      a, b = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None' input: '^A' }
          """),
          return_elements=['A', 'B'])

      self.assertEqual(b.control_inputs, [a])

  def testWithRefs(self):
    """Ref-typed outputs feed both ref and non-ref inputs correctly."""
    with tf.Graph().as_default():
      a, b, c, d = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Or' }
          node { name: 'B' op: 'Oi' }
          node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' }
          """),
          return_elements=['A', 'B', 'C', 'D'])

      self.assertEqual(c.inputs[0], a.outputs[0])
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[0])
      self.assertEqual(d.inputs[1], b.outputs[0])

      self.assertEqual(a.outputs[0].dtype, tf.int32_ref)
      self.assertEqual(c._input_dtypes, [tf.int32, tf.int32])
      self.assertEqual(c.outputs, [])
      self.assertEqual(d._input_dtypes,
                       [tf.int32_ref, tf.int32])
      self.assertEqual(d.outputs, [])

  def testCyclic(self):
    """Imports two nodes that consume each other's outputs."""
    with tf.Graph().as_default():
      a, b = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Unary'
                 attr { key: 'T' value { type: DT_INT32 } } input: 'B:0' }
          node { name: 'B' op: 'Unary'
                 attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' }
          """),
          return_elements=['A', 'B'])

      self.assertEqual(a.inputs[0], b.outputs[0])
      self.assertEqual(b.inputs[0], a.outputs[0])

  def testTypeMismatchInGraphDef(self):
    """An int32 output wired to a float input is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            node { name: 'B' op: 'If' input: 'A:0' }
            """))
      self.assertIn(
          'Cannot convert a tensor of type int32 to an input of type float',
          str(e.exception))

  def testInvalidSignatureTooManyInputsInGraphDef(self):
    """A node with more inputs than its op declares is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            node { name: 'B' op: 'None' input: 'A:0' }
            """))
      self.assertIn('More inputs specified (\'A:0\') than the op expects',
                    str(e.exception))

  def testInvalidSignatureNotEnoughInputsInGraphDef(self):
    """A node with fewer inputs than its op declares is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            node { name: 'B' op: 'Iif' input: 'A:0' }
            """))
      self.assertIn('Input types mismatch (expected \'int32, float32\' but '
                    'got \'int32\')', str(e.exception))

  def testMissingInputOpInGraphDef(self):
    """An input referring to a nonexistent node is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'If' input: 'A:0' }
            """))
      self.assertIn("Input tensor 'A:0' not found", str(e.exception))

  def testMissingInputOpInGraphDefButAppearsInInputMap(self):
    """A missing input is accepted when `input_map` supplies it."""
    with tf.Graph().as_default():
      feed_a_0 = tf.constant(5.0)
      b, = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'B' op: 'If' input: 'A:0' }
          """),
          input_map={'A:0': feed_a_0},
          return_elements=['B'])
      self.assertEqual(b.inputs[0], feed_a_0)

  def testMissingInputTensorInGraphDef(self):
    """An input referring to a nonexistent output index is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Of' }
            node { name: 'B' op: 'If' input: 'A:1' }
            """))
      self.assertIn("Input tensor 'A:1' not found", str(e.exception))

  def testMissingControlInputInGraphDef(self):
    """A control input referring to a nonexistent node is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: '^A' }
            """))
      self.assertIn("Control input '^A' not found", str(e.exception))

  def testInvalidTensorNameOutputIndexInGraphDef(self):
    """A non-numeric output index in an input name is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: 'A:B' }
            """))
      self.assertEqual("Cannot convert 'A:B' to a tensor name.",
                       str(e.exception))

  def testInvalidTensorNameInGraphDef(self):
    """An input name with too many ':' separators is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: 'A:B:0' }
            """))
      self.assertEqual("Cannot convert 'A:B:0' to a tensor name.",
                       str(e.exception))

  def testMissingReturnOperation(self):
    """Requesting an op name absent from the GraphDef is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'None' }
            """),
            return_elements=['B'])
      self.assertIn("return_element 'B' not found in graph_def.",
                    str(e.exception))

  def testMissingReturnTensor(self):
    """Requesting a bad or absent tensor name is rejected."""
    with tf.Graph().as_default():
      # Respectively: an out-of-range output index, a missing node, and a
      # malformed tensor name.
      for missing in ['A:1', 'B:0', 'A:B:0']:
        with self.assertRaises(ValueError) as e:
          tf.import_graph_def(
              self._MakeGraphDef("""
              node { name: 'A' op: 'Oi' }
              """),
              return_elements=[missing])
        self.assertIn("return_element '%s' not found in graph_def." % missing,
                      str(e.exception))

  def testMissingInputMap(self):
    """An `input_map` key that matches no tensor in the GraphDef is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'None' }
            """),
            input_map={'B:0': tf.constant(5.0)})
      self.assertIn('not found in graph_def: [B:0]', str(e.exception))

  def testInputMapTypeMismatch(self):
    """An `input_map` value whose dtype does not match the input is rejected."""
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            node { name: 'B' op: 'Ii' input: 'A:0' }
            """),
            input_map={'A:0': tf.constant(5.0)})
      self.assertIn(
          'Cannot convert a tensor of type float32 to an input of type int32.',
          str(e.exception))

  def testNoReturns(self):
    """Without `return_elements` the import returns None but adds the ops."""
    with tf.Graph().as_default() as g:
      ret = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          """))
      self.assertIsNone(ret)

      a = g.get_operation_by_name('import/A')
      self.assertEqual(a.type, 'None')

  def testOverrideNamePrefix(self):
    """The `name` argument replaces the default 'import' name prefix."""
    with tf.Graph().as_default():
      a, = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          """),
          return_elements=['A'], name='imported_graph')
      self.assertEqual(a.name, 'imported_graph/A')

  def testEmptyGraph(self):
    """Importing an empty GraphDef leaves the graph's version unchanged."""
    with tf.Graph().as_default() as g:
      init_version = g.version
      tf.import_graph_def(self._MakeGraphDef(''))
      self.assertEqual(init_version, g.version)

  def testInvalidInputForGraphDef(self):
    """A non-GraphDef `graph_def` argument raises TypeError."""
    with tf.Graph().as_default():
      with self.assertRaises(TypeError) as e:
        tf.import_graph_def('')
      self.assertEqual(
          'graph_def must be a GraphDef proto.', str(e.exception))

  def testInvalidInputForInputMap(self):
    """A non-dict `input_map` raises TypeError."""
    with tf.Graph().as_default():
      with self.assertRaises(TypeError) as e:
        tf.import_graph_def(self._MakeGraphDef(''),
                            input_map=[tf.constant(5.0)])
      self.assertEqual('input_map must be a dictionary mapping strings to '
                       'Tensor objects.', str(e.exception))

  def testInvalidInputForReturnOperations(self):
    """Non-string `return_elements` entries raise TypeError."""
    with tf.Graph().as_default():
      with self.assertRaises(TypeError) as e:
        tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7])
      self.assertEqual(
          'return_elements must be a list of strings.', str(e.exception))

  def testWithExtensionAndAttr(self):
    """Round-trips a graph built with real ops and evaluates the import."""
    with tf.Graph().as_default() as g:
      c = tf.constant(5.0, dtype=tf.float32, name='c')
      tf.pack([c, c], name='pack')
    gdef = g.as_graph_def()

    with self.test_session():
      pack, = tf.import_graph_def(gdef, return_elements=['pack'])
      self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])

  def testWithDevice(self):
    """Imported device specs are merged with the enclosing device scope."""
    with tf.Graph().as_default() as g:
      # No device.
      a = tf.constant(3.0, name='a')

      with tf.device('/cpu:0'):
        b = tf.constant(4.0, name='b')
      with tf.device('/job:worker'):
        c = tf.constant(5.0, name='c')

    gdef = g.as_graph_def()

    with tf.Graph().as_default():
      a2, b2, c2 = tf.import_graph_def(
          gdef, return_elements=['a', 'b', 'c'])
      self.assertEqual(a.device, a2.device)
      self.assertEqual(b.device, b2.device)
      self.assertEqual(c.device, c2.device)

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/task:0')):
        a3, b3, c3 = tf.import_graph_def(
            gdef, return_elements=['a', 'b', 'c'])
        self.assertEqual('/task:0', a3.device)
        self.assertEqual('/task:0/device:CPU:0', b3.device)  # canonicalized.
        self.assertEqual(c.device + '/task:0', c3.device)

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/job:ps')):
        a4, b4, c4 = tf.import_graph_def(
            gdef, return_elements=['a', 'b', 'c'])
        self.assertEqual('/job:ps', a4.device)
        self.assertEqual('/job:ps/device:CPU:0', b4.device)  # canonicalized.
        self.assertEqual(c.device, c4.device)  # worker overrides ps.

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/gpu:0')):
        a5, b5, c5 = tf.import_graph_def(
            gdef, return_elements=['a', 'b', 'c'])
        self.assertEqual('/device:GPU:0', a5.device)
        self.assertEqual('/device:CPU:0', b5.device)  # cpu overrides gpu.
        self.assertEqual(c.device + '/device:GPU:0', c5.device)

  def testGradient(self):
    """Gradients flow through an imported graph into mapped Variables."""
    with tf.Graph().as_default() as g:
      inputs = tf.placeholder(tf.float32, shape=[None, 100], name="input")
      weights = tf.placeholder(tf.float32, shape=[100, 10], name="weights")
      biases = tf.placeholder(tf.float32, shape=[10], name="biases")
      activations = tf.nn.relu(tf.matmul(inputs, weights) + biases,
                               name="activations")
      loss = tf.reduce_mean(activations, name="loss")
    gdef = g.as_graph_def()

    with tf.Graph().as_default():
      input_placeholder = tf.placeholder(tf.float32, shape=[32, 100])
      weights_var = tf.Variable(tf.truncated_normal([100, 10]), name="weights")
      biases_var = tf.Variable(tf.zeros(10), name="biases")
      activations, loss = tf.import_graph_def(
          gdef,
          input_map={"input:0": input_placeholder,
                     "weights:0": weights_var,
                     "biases:0": biases_var},
          return_elements=["activations:0", "loss:0"])
      self.assertEqual([32, 10], activations.get_shape())
      self.assertEqual([], loss.get_shape())
      weights_grad, biases_grad = tf.gradients(loss, [weights_var, biases_var])
      self.assertEqual([100, 10], weights_grad.get_shape())
      self.assertEqual([10], biases_grad.get_shape())

  def testLargeGraph(self):
    """Evaluates a graph holding a very large constant tensor."""
    with self.test_session():
      # The default message byte limit is 64M. Ours is 2G with a warning at 512.
      # Adding a 150M entries float32 tensor should blow through the warning,
      # but not the hard limit.
      input_shape = [150, 1024, 1024]
      tensor_input = np.random.rand(*input_shape).astype(np.float32)
      t = tf.constant(tensor_input, shape=input_shape)
      identity = tf.identity(t)
      identity.eval()

  def testVersion(self):
    """GraphDefs at the minimum and maximum supported versions import."""
    for version in tf.GRAPH_DEF_VERSION_MIN, tf.GRAPH_DEF_VERSION_MAX:
      with tf.Graph().as_default():
        a, = tf.import_graph_def(
            self._MakeGraphDef("node { name: 'A' op: 'Oii' }", version=version),
            return_elements=['A'])
        self.assertEqual(a.graph.graph_def_version, version)

  def testVersionLow(self):
    """A GraphDef version below the supported range is rejected."""
    with tf.Graph().as_default():
      pat = (r"^GraphDef version -1 is no longer supported: TensorFlow \S+ "
             r"needs \d+ <= version <= \d+.  Please regenerate your graph.$")
      with self.assertRaisesRegexp(ValueError, pat):
        tf.import_graph_def(self._MakeGraphDef("", version=-1))

  def testVersionHigh(self):
    """A GraphDef version above the supported range is rejected."""
    with tf.Graph().as_default():
      pat = (r"^GraphDef version \d+ is not yet supported: TensorFlow \S+ "
             r"needs \d+ <= version <= \d+.  Please upgrade TensorFlow.$")
      with self.assertRaisesRegexp(ValueError, pat):
        tf.import_graph_def(self._MakeGraphDef("", version=1 << 30))

if __name__ == '__main__':
  # Entry point when this file is executed directly: hand control to
  # TensorFlow's test runner.
  tf.test.main()