提交 1e0e804c 编写于 作者: S Scott Zhu 提交者: TensorFlower Gardener

Remove unused _SlimRNNCell.

This cell type has been private for a while, and neither used by internal code, nor exposed as a public API.

PiperOrigin-RevId: 203853628
上级 8bc3844a
......@@ -118,7 +118,6 @@ cuda_py_tests(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:rnn",
"//tensorflow/python:rnn_cell",
"//tensorflow/python:variable_scope",
......
......@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import numpy as np
......@@ -35,7 +34,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
......@@ -935,50 +933,6 @@ class DropoutWrapperTest(test.TestCase):
self.assertAllClose(res0[1].h, res1[1].h)
class SlimRNNCellTest(test.TestCase):
def testBasicRNNCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
my_cell = functools.partial(basic_rnn_cell, num_units=2)
# pylint: disable=protected-access
g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
# pylint: enable=protected-access
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([g], {
x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1]])
})
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellMatch(self):
batch_size = 32
input_size = 100
num_units = 10
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inputs = random_ops.random_uniform((batch_size, input_size))
_, initial_state = basic_rnn_cell(inputs, None, num_units)
rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
outputs, state = rnn_cell(inputs, initial_state)
variable_scope.get_variable_scope().reuse_variables()
my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
# pylint: disable=protected-access
slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
# pylint: enable=protected-access
slim_outputs, slim_state = slim_cell(inputs, initial_state)
self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
self.assertEqual(slim_state.get_shape(), state.get_shape())
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([slim_outputs, slim_state, outputs, state])
self.assertAllClose(res[0], res[2])
self.assertAllClose(res[1], res[3])
def basic_rnn_cell(inputs, state, num_units, scope=None):
if state is None:
if inputs is not None:
......
......@@ -47,7 +47,6 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
......@@ -1330,48 +1329,3 @@ class MultiRNNCell(RNNCell):
array_ops.concat(new_states, 1))
return cur_inp, new_states
class _SlimRNNCell(RNNCell, checkpointable_tracking.NotCheckpointable):
"""A simple wrapper for slim.rnn_cells."""
def __init__(self, cell_fn):
"""Create a SlimRNNCell from a cell_fn.
Args:
cell_fn: a function which takes (inputs, state, scope) and produces the
outputs and the new_state. Additionally when called with inputs=None and
state=None it should return (initial_outputs, initial_state).
Raises:
TypeError: if cell_fn is not callable
ValueError: if cell_fn cannot produce a valid initial state.
"""
if not callable(cell_fn):
raise TypeError("cell_fn %s needs to be callable", cell_fn)
self._cell_fn = cell_fn
self._cell_name = cell_fn.func.__name__
init_output, init_state = self._cell_fn(None, None)
output_shape = init_output.get_shape()
state_shape = init_state.get_shape()
self._output_size = output_shape.with_rank(2)[1].value
self._state_size = state_shape.with_rank(2)[1].value
if self._output_size is None:
raise ValueError("Initial output created by %s has invalid shape %s" %
(self._cell_name, output_shape))
if self._state_size is None:
raise ValueError("Initial state created by %s has invalid shape %s" %
(self._cell_name, state_shape))
@property
def state_size(self):
return self._state_size
@property
def output_size(self):
return self._output_size
def __call__(self, inputs, state, scope=None):
scope = scope or self._cell_name
output, state = self._cell_fn(inputs, state, scope=scope)
return output, state
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册