提交 54a644f3 编写于 作者: V Vijay Vasudevan

TensorFlow: upstream changes to git

Change 109366961
	TensorFlow BUILD: now that we have an ops library,
	set linkstatic to 1. This fixes a breakage in the would-be
	opensource build, and it *might* mean we can get rid of
	all of the RequireDefaultOps() calls in our code.

	The ops library is much smaller than the kernels library that was
	previously linked together.  We set linkstatic=0 presumably since we
	didn't want to package a static copy of the kernels (very large)
	everywhere.  But the op definitions are small, so this seems like a
	safe change to make.  Time to build the various tests was not
	any longer after this change, and inspecting the example_trainer
	binary showed no large increase.
Change 109363613
	TensorFlow: new graph_def_builder_test needs to RequireDefaultOps.
Change 109362569
	Split ":ops" out of ":kernels" target in tensorflow/core.
Change 109360666
	Catch dtype and some shape errors sooner in `QueueBase`.

	Some avoidable errors were not being caught (e.g. the dtypes of the
	enqueue components were not checked against the queue's dtypes in
	Python), leading to cryptic messages at runtime. After this CL, they
	will be caught earlier.
Change 109359569
	TensorFlow: Expect g_ != nullptr in test
Change 109350735
	Add a version number to GraphDef

	We would like to be able to deprecate behavior in newly generated graphs
	without invalidating tensorflow's ability to read and evaluate old graphs.
	For this purpose, GraphDef now has a version field which can be checked inside
	op kernels to determine how backwards compatible to be.  version.h defines
	TF_GRAPHDEF_VERSION_MIN and TF_GRAPHDEF_VERSION_MAX specifying the range of
	supported GraphDef versions in the current version of tensorflow.

	Also expose tf.__version__ and tf.__graph_def_version{,_min,_max}__ for Python
	interrogation purposes.

	Whenever we want to deprecate or change some GraphDef semantics, we will
	proceed as follows:

	1. Bump TF_GRAPHDEF_VERSION_MAX, leaving TF_GRAPHDEF_VERSION_MIN unchanged.
	   Describe the change in graph.proto, include the date introduced.

	2. In each relevant kernel, implement the new behavior if the GraphDef version
	   is new, but preserve the old behavior for previous GraphDef versions.

	3. Wait six months or so (we need to formalize this somewhere).

	4. Bump TF_GRAPHDEF_VERSION_MIN and remove the backwards compatibility.

	The GraphDef version is distinct from the open source version, but at least
	(4) and possibly (1) correspond to major version number bumps.

	The first GraphDef version bump is the upcoming scalar strictness change,
	which affects Google users only since open source is already scalar strict.

	This commit does not yet plumb the version number into OpKernelConstruction
	so that ops can access it.  That will follow.
Change 109350260
	Made TensorShapeProto implicitly convertible to TensorShape.

Base CL: 109366982
上级 eb5e56e4
......@@ -187,10 +187,7 @@ cc_library(
"graph/testlib.h",
],
copts = tf_copts(),
visibility = [
":friends",
"//tensorflow:internal",
],
visibility = ["//visibility:public"],
deps = [
":core_cpu",
":tensorflow",
......@@ -213,11 +210,9 @@ tf_cuda_library(
)
tf_cuda_library(
name = "kernels",
name = "ops",
srcs = glob(
[
"kernels/**/*.h",
"kernels/**/*.cc",
"ops/**/*.h",
"ops/**/*.cc",
"user_ops/**/*.h",
......@@ -226,14 +221,38 @@ tf_cuda_library(
exclude = [
"**/*test*",
"**/*main.cc",
"kernels/**/*.cu.cc",
"user_ops/**/*.cu.cc",
],
),
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
":core",
":lib",
":protos_cc",
"//tensorflow/models/embedding:word2vec_ops",
"//third_party/eigen3",
],
alwayslink = 1,
)
tf_cuda_library(
name = "kernels",
srcs = glob(
[
"kernels/**/*.h",
"kernels/**/*.cc",
],
exclude = [
"**/*test*",
"**/*main.cc",
"kernels/**/*.cu.cc",
],
),
copts = tf_copts(),
cuda_deps = [
":gpu_kernels",
":cuda",
],
linkstatic = 0,
visibility = ["//visibility:public"],
......@@ -241,10 +260,10 @@ tf_cuda_library(
"@gemmlowp//:eight_bit_int_gemm",
":core",
":lib",
":ops",
":protos_cc",
":stream_executor",
"//tensorflow/models/embedding:word2vec_kernels",
"//tensorflow/models/embedding:word2vec_ops",
"//third_party/eigen3",
],
alwayslink = 1,
......@@ -262,6 +281,7 @@ tf_gpu_kernel_library(
),
visibility = ["//visibility:public"],
deps = [
":cuda",
"//third_party/eigen3",
],
)
......@@ -416,6 +436,7 @@ tf_cc_tests(
":direct_session",
":kernels",
":lib",
":ops",
":strict_headers",
":test_main",
":testlib",
......
......@@ -164,6 +164,11 @@ Status DirectSession::Extend(const GraphDef& graph) {
}
Status DirectSession::ExtendLocked(const GraphDef& graph) {
if (graph_created_ && graph_def_.version() != graph.version()) {
return errors::InvalidArgument("Incompatible GraphDef versions in Extend: ",
graph_def_.version(), " != ",
graph.version());
}
graph_created_ = true; // In case this is first call
graph_def_.MergeFrom(graph);
return Status::OK();
......
......@@ -980,6 +980,7 @@ static void ToGraphDef(const Graph* g, GraphDef* gdef) {
}
gtl::InlinedVector<const Edge*, 4> inputs;
gdef->Clear();
gdef->set_version(g->version());
while (!ready.empty()) {
const Node* n = ready.front();
ready.pop_front();
......
......@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace test {
......@@ -27,6 +28,7 @@ typedef FunctionDefHelper FDH;
GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
gtl::ArraySlice<FunctionDef> funcs) {
GraphDef g;
g.set_version(TF_GRAPH_DEF_VERSION);
for (auto n : nodes) {
*(g.add_node()) = n;
}
......
......@@ -15,6 +15,17 @@ import "tensorflow/core/framework/function.proto";
message GraphDef {
repeated NodeDef node = 1;
// Compatibility version of the graph. Newly created graphs use
// the most recent version. Version history:
//
// 0. Graphs created before GraphDef versioning
// 1. First real version (2dec2015)
//
// The GraphDef version is distinct from the TensorFlow version.
// Each released version of TensorFlow will support a range of
// GraphDef versions.
int32 version = 3;
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
//
// "library" provides user-defined functions.
......
......@@ -24,6 +24,7 @@ namespace tensorflow {
string SummarizeGraphDef(const GraphDef& graph_def) {
string ret;
strings::StrAppend(&ret, "version = ", graph_def.version(), ";\n");
for (const NodeDef& node : graph_def.node()) {
strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
}
......
......@@ -26,6 +26,14 @@ namespace tensorflow {
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff) {
if (actual.version() != expected.version()) {
if (diff != nullptr) {
*diff = strings::StrCat("Expected version ", expected.version(),
", got version ", actual.version());
}
return false;
}
std::unordered_map<string, const NodeDef*> actual_index;
for (const NodeDef& node : actual.node()) {
actual_index[node.name()] = &node;
......
......@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
......@@ -88,10 +89,11 @@ TEST_F(EqualGraphDefTest, ExtraNode) {
Input(a_.opts().WithName("A"));
Input(a_.opts().WithName("B"));
EXPECT_FALSE(Match());
EXPECT_EQ(
"Found unexpected node 'B = Input[]()' not in expected graph:\n"
"A = Input[]();\n",
diff_);
EXPECT_EQ(strings::StrCat(
"Found unexpected node 'B = Input[]()' not in expected graph:\n"
"version = ",
TF_GRAPH_DEF_VERSION, ";\nA = Input[]();\n"),
diff_);
}
TEST_F(EqualGraphDefTest, NodeOrder) {
......
......@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
......@@ -105,7 +106,7 @@ Node::Properties::~Properties() {}
// Graph
Graph::Graph(const OpRegistryInterface* ops)
: ops_(ops), arena_(8 << 10 /* 8kB */) {
: ops_(ops), version_(TF_GRAPH_DEF_VERSION), arena_(8 << 10 /* 8kB */) {
// Source and sink have no endpoints, just control edges.
NodeDef def;
def.set_name("_SOURCE");
......@@ -253,6 +254,7 @@ void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
void Graph::ToGraphDef(GraphDef* graph_def) const {
graph_def->Clear();
graph_def->set_version(version());
std::vector<const Edge*>
inputs; // Construct this outside the loop for speed.
for (const Node* node : nodes()) {
......
......@@ -187,11 +187,17 @@ class Graph {
// single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
//
// The graph can hold ops found in registry.
//
// The version defaults to TF_GRAPH_DEF_VERSION.
explicit Graph(const OpRegistryInterface* registry);
~Graph();
static const int kControlSlot = -1;
// The GraphDef version of this graph (see graph.proto).
int version() const { return version_; }
void set_version(int version) { version_ = version; }
// Adds a new node to this graph, and returns it. Infers the Op and
// input/output types for the node. *this owns the returned instance.
// Returns nullptr and sets *status on error.
......@@ -274,6 +280,9 @@ class Graph {
// Registry of all known ops. Not owned.
const OpRegistryInterface* const ops_;
// GraphDef version
int version_;
// Allocator which will give us good locality.
core::Arena arena_;
......
......@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
......@@ -45,6 +46,19 @@ class GraphConstructor {
GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef,
Graph* g, Status* status)
: opts_(opts), gdef_(gdef), g_(g), status_(status) {
const int version = gdef->version();
if (!(TF_GRAPH_DEF_VERSION_MIN <= version &&
version <= TF_GRAPH_DEF_VERSION_MAX)) {
bool low = version < TF_GRAPH_DEF_VERSION_MAX;
*status = errors::InvalidArgument(
"GraphDef version ", version, " is ", low ? "no longer" : "not yet",
" supported: TensorFlow ", TF_VERSION_STRING, " needs ",
TF_GRAPH_DEF_VERSION_MAX, " <= version <= ", TF_GRAPH_DEF_VERSION_MIN,
". ",
low ? "Please regenerate your graph." : "Please upgrade TensorFlow.");
return;
}
g->set_version(gdef->version());
BuildNodeIndex();
InitFromEdges();
Convert();
......
......@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/version.h"
// TODO(josh11b): Test InitCostModel().
// TODO(josh11b): Test setting the "device" field of a NodeDef.
......@@ -58,6 +59,12 @@ class GraphConstructorTest : public ::testing::Test {
TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
}
void ExpectVersion(int version) {
EXPECT_NE(nullptr, g_);
EXPECT_EQ(version, g_->version()) << "Expected version " << version
<< ", got " << g_->version();
}
Node* FindNode(const string& name) {
for (Node* n : g_->nodes()) {
if (n->name() == name) return n;
......@@ -160,7 +167,30 @@ TEST_F(GraphConstructorTest, TypeMismatch) {
"expected int32.");
}
TEST_F(GraphConstructorTest, EmptyGraph) { ExpectOK(""); }
TEST_F(GraphConstructorTest, EmptyGraph) {
ExpectOK("");
ExpectVersion(0); // The default GraphDef version is 0
}
TEST_F(GraphConstructorTest, VersionGraph) {
ASSERT_LT(0, TF_GRAPH_DEF_VERSION); // Verify the assertion is nontrivial
ExpectOK(strings::StrCat("version: ", TF_GRAPH_DEF_VERSION));
ExpectVersion(TF_GRAPH_DEF_VERSION);
}
TEST_F(GraphConstructorTest, LowVersion) {
ExpectError(strings::StrCat("version: ", -1),
R"(^GraphDef version -1 is no longer supported: TensorFlow \S+ )"
R"(needs \d+ <= version <= \d+\. )"
R"(Please regenerate your graph\.$)");
}
TEST_F(GraphConstructorTest, HighVersion) {
ExpectError(strings::StrCat("version: ", TF_GRAPH_DEF_VERSION_MAX + 1),
R"(^GraphDef version \d+ is not yet supported: TensorFlow \S+ )"
R"(needs \d+ <= version <= \d+\. )"
R"(Please upgrade TensorFlow\.$)");
}
TEST_F(GraphConstructorTest, SimpleModel) {
ExpectOK(
......
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/graph/graph_def_builder.h"
#include <gtest/gtest.h>
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
TEST(GraphDefBuilderTest, Version) {
RequireDefaultOps();
// Verify that our assertions will be nontrivial
ASSERT_LT(0, TF_GRAPH_DEF_VERSION);
// Newly built graphs should use the current version
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
// Check version when we convert to a Graph
Graph graph(OpRegistry::Global());
EXPECT_OK(builder.ToGraph(&graph));
ASSERT_EQ(graph.version(), TF_GRAPH_DEF_VERSION);
// Check version when we convert to a GraphDef
GraphDef graph_def;
EXPECT_OK(builder.ToGraphDef(&graph_def));
ASSERT_EQ(graph_def.version(), TF_GRAPH_DEF_VERSION);
}
} // namespace
} // namespace tensorflow
......@@ -1051,6 +1051,11 @@ Status Partition(const PartitionOptions& opts, Graph* g,
}
}
// Set versions
for (auto& it : *partitions) {
it.second.set_version(g->version());
}
// Set the start times for recvs at the very end.
if (opts.scheduling_for_recvs) {
for (auto& it : dup_recv) {
......
......@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
......@@ -72,6 +73,12 @@ void Partition(const GraphDef& graph_def,
popts.control_flow_added = false;
Status s = Partition(popts, &g, partitions);
CHECK(s.ok()) << s;
// Check versions
EXPECT_EQ(graph_def.version(), TF_GRAPH_DEF_VERSION);
for (auto& it : *partitions) {
EXPECT_EQ(graph_def.version(), it.second.version());
}
}
void CheckLoopConstruction(const GraphDef& graph_def) {
......
......@@ -36,4 +36,9 @@ limitations under the License.
// TODO(josh11b): Public API functions for exporting the above.
// Supported GraphDef versions (see graph.proto).
#define TF_GRAPH_DEF_VERSION_MIN 0
#define TF_GRAPH_DEF_VERSION_MAX 1
#define TF_GRAPH_DEF_VERSION TF_GRAPH_DEF_VERSION_MAX
#endif // THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_
......@@ -138,6 +138,7 @@ py_library(
"framework/tensor_shape.py",
"framework/dtypes.py",
"framework/tensor_util.py",
"framework/versions.py",
"ops/common_shapes.py",
],
srcs_version = "PY2AND3",
......@@ -195,6 +196,18 @@ py_test(
],
)
py_test(
name = "framework_versions_test",
srcs = ["framework/versions_test.py"],
main = "framework/versions_test.py",
srcs_version = "PY2AND3",
deps = [
":framework_test_lib",
":platform_test",
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "framework_importer_test",
srcs = ["framework/importer_test.py"],
......
......@@ -48,6 +48,7 @@ from tensorflow.core.util.event_pb2 import *
# Framework
from tensorflow.python.framework.framework_lib import *
from tensorflow.python.framework.versions import *
from tensorflow.python.framework import errors
# Session
......@@ -81,3 +82,4 @@ _whitelist = set([app, compat, errors, flags, image, logging, nn,
_whitelist.update([ops, tensor_util]) # pylint: disable=undefined-variable
__all__ = [name for name, x in locals().items() if not name.startswith('_') and
(not inspect.ismodule(x) or x in _whitelist)]
__all__.append('__version__')
......@@ -34,6 +34,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
......@@ -425,7 +426,8 @@ class SessionTest(test_util.TensorFlowTestCase):
def testGraphDef(self):
with session.Session() as sess:
self.assertProtoEquals('', sess.graph_def)
self.assertProtoEquals('version: %d' % versions.GRAPH_DEF_VERSION,
sess.graph_def)
c = constant_op.constant(5.0, name='c')
self.assertEquals(len(sess.graph_def.node), 1)
d = constant_op.constant(6.0, name='d')
......
......@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/version.h"
%}
......@@ -32,6 +33,12 @@ limitations under the License.
tensorflow::ImportNumpy();
%}
// TensorFlow version and GraphDef versions
%constant const char* __version__ = TF_VERSION_STRING;
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
%constant int GRAPH_DEF_VERSION_MIN = TF_GRAPH_DEF_VERSION_MIN;
%constant int GRAPH_DEF_VERSION_MAX = TF_GRAPH_DEF_VERSION_MAX;
// Release the Python GIL for the duration of most methods.
%exception {
Py_BEGIN_ALLOW_THREADS;
......
......@@ -215,6 +215,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
with ops.op_scope(input_map.values(), name, 'import'):
g = ops.get_default_graph()
g.graph_def_version = graph_def.version
with ops.name_scope('_inputs'):
input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
......
......@@ -111,7 +111,8 @@ for op_def in _op_list.op:
class ImportGraphDefTest(tf.test.TestCase):
def _MakeGraphDef(self, text):
def _MakeGraphDef(self, text, version=tf.GRAPH_DEF_VERSION):
text = "version: %d\n%s" % (version, text)
ret = tf.GraphDef()
text_format.Merge(text, ret)
return ret
......@@ -610,6 +611,28 @@ class ImportGraphDefTest(tf.test.TestCase):
g = tf.identity(t)
g.eval()
def testVersion(self):
for version in tf.GRAPH_DEF_VERSION_MIN, tf.GRAPH_DEF_VERSION_MAX:
with tf.Graph().as_default():
a, = tf.import_graph_def(
self._MakeGraphDef("node { name: 'A' op: 'Oii' }", version=version),
return_elements=['A'])
self.assertEqual(a.graph.graph_def_version, version)
def testVersionLow(self):
with tf.Graph().as_default():
pat = (r"^GraphDef version -1 is no longer supported: TensorFlow \S+ "
r"needs \d+ <= version <= \d+. Please regenerate your graph.$")
with self.assertRaisesRegexp(ValueError, pat):
tf.import_graph_def(self._MakeGraphDef("", version=-1))
def testVersionHigh(self):
with tf.Graph().as_default():
pat = (r"^GraphDef version \d+ is not yet supported: TensorFlow \S+ "
r"needs \d+ <= version <= \d+. Please upgrade TensorFlow.$")
with self.assertRaisesRegexp(ValueError, pat):
tf.import_graph_def(self._MakeGraphDef("", version=1 << 30))
if __name__ == '__main__':
tf.test.main()
......@@ -37,6 +37,7 @@ from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.util import compat
......@@ -1545,6 +1546,7 @@ class Graph(object):
@@seed
@@unique_name
@@version
@@graph_def_version
@@create_op
@@gradient_override_map
......@@ -1585,6 +1587,8 @@ class Graph(object):
self._finalized = False
# Functions defined in the graph
self._functions = collections.OrderedDict()
# Default GraphDef version
self._graph_def_version = versions.GRAPH_DEF_VERSION
def _check_not_finalized(self):
"""Check if the graph is finalized.
......@@ -1620,9 +1624,36 @@ class Graph(object):
@property
def version(self):
"""Returns a version number that increases as ops are added to the graph."""
"""Returns a version number that increases as ops are added to the graph.
Note that this is unrelated to the
[GraphDef version](#Graph.graph_def_version).
"""
return self._next_id_counter
@property
def graph_def_version(self):
"""The GraphDef version of this graph.
For details on the meaning of each version, see [`GraphDef`]
(https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto).
"""
return self._graph_def_version
@graph_def_version.setter
def graph_def_version(self, version):
if not (versions.GRAPH_DEF_VERSION_MIN <= version <=
versions.GRAPH_DEF_VERSION_MAX):
low = version < versions.GRAPH_DEF_VERSION_MIN
raise ValueError(
"GraphDef version %d is %s supported: TensorFlow %s needs %d <= "
"version <= %d. Please %s." %
(version, "no longer" if low else "not yet",
versions.__version__, versions.GRAPH_DEF_VERSION_MIN,
versions.GRAPH_DEF_VERSION_MAX,
"regenerate your graph" if low else "upgrade TensorFlow"))
self._graph_def_version = version
@property
def seed(self):
return self._seed
......@@ -1684,6 +1715,7 @@ class Graph(object):
ValueError: If the `graph_def` would be too large.
"""
graph = graph_pb2.GraphDef()
graph.version = self._graph_def_version
bytesize = 0
for op_id in sorted(self._nodes_by_id):
op = self._nodes_by_id[op_id]
......
......@@ -410,7 +410,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
op = g.create_op("an_op", [], [dtypes.float32])
self.assertEqual(None, op.device)
gd = g.as_graph_def()
self.assertProtoEquals("""
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op" }
""", gd)
......@@ -419,7 +419,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
with g.device("/job:worker/replica:2"):
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
self.assertProtoEquals("""
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" }
""", gd)
......@@ -430,7 +430,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
device_index=3)):
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
self.assertProtoEquals("""
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2/task:0/device:CPU:3" }
""", gd)
......@@ -443,7 +443,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
self.assertProtoEquals("""
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2" }
node { name: "an_op_1" op: "an_op"
......@@ -460,7 +460,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
self.assertProtoEquals("""
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2" }
node { name: "an_op_1" op: "an_op"
......@@ -477,7 +477,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
self.assertProtoEquals("""
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2/device:CPU:1" }
node { name: "an_op_1" op: "an_op"
......@@ -501,7 +501,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
self.assertProtoEquals("""
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/device:GPU:0" }
node { name: "an_op_1" op: "an_op"
......@@ -522,7 +522,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
self.assertProtoEquals("""
self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2/device:CPU:1" }
node { name: "an_op_1" op: "an_op" }
......
......@@ -20,6 +20,8 @@ from __future__ import print_function
import tensorflow.python.platform
from tensorflow.core.framework import tensor_shape_pb2
class Dimension(object):
"""Represents the value of one dimension in a TensorShape."""
......@@ -407,6 +409,8 @@ class TensorShape(object):
# TODO(irving): Eliminate the single integer special case.
if dims is None:
self._dims = None
elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
self._dims = [as_dimension(dim.size) for dim in dims.dim]
else:
try:
dims_iter = iter(dims)
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
import tensorflow.python.platform
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
......@@ -254,6 +255,19 @@ class ShapeTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(TypeError, r"unsupported operand type"):
unknown / unknown # pylint: disable=pointless-statement
def testConvertFromProto(self):
proto = tensor_util.MakeTensorShapeProto([])
self.assertEqual(tensor_shape.TensorShape([]),
tensor_shape.TensorShape(proto))
self.assertEqual(tensor_shape.TensorShape([]),
tensor_shape.as_shape(proto))
proto = tensor_util.MakeTensorShapeProto([1, 37, 42])
self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
tensor_shape.TensorShape(proto))
self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
tensor_shape.as_shape(proto))
if __name__ == "__main__":
googletest.main()
......@@ -38,6 +38,7 @@ from tensorflow.python.client import graph_util
from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.platform import googletest
from tensorflow.python.platform import logging
from tensorflow.python.util.protobuf import compare
......@@ -113,6 +114,11 @@ class TensorFlowTestCase(googletest.TestCase):
type(expected_message_maybe_ascii) + " and " +
type(message))
def assertProtoEqualsVersion(self, expected, actual,
version=versions.GRAPH_DEF_VERSION):
expected = "version: %d\n%s" % (version, expected)
self.assertProtoEquals(expected, actual)
def assertStartsWith(self, actual, expected_start, msg=None):
"""Assert that actual.startswith(expected_start) is True.
......
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow versions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
from tensorflow.python import pywrap_tensorflow
__version__ = pywrap_tensorflow.__version__
GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION
GRAPH_DEF_VERSION_MIN = pywrap_tensorflow.GRAPH_DEF_VERSION_MIN
GRAPH_DEF_VERSION_MAX = pywrap_tensorflow.GRAPH_DEF_VERSION_MAX
# Make sure these symbols are exported even though one starts with _.
__all__ = ["__version__", "GRAPH_DEF_VERSION", "GRAPH_DEF_VERSION_MIN",
"GRAPH_DEF_VERSION_MAX"]
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exposed tensorflow versions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
import tensorflow as tf
class VersionTest(tf.test.TestCase):
def testVersion(self):
self.assertEqual(type(tf.__version__), str)
# This pattern will need to grow as we include alpha, builds, etc.
self.assertRegexpMatches(tf.__version__, r'^\d+\.\d+\.\d+$')
def testGraphDefVersion(self):
version = tf.GRAPH_DEF_VERSION
min = tf.GRAPH_DEF_VERSION_MIN
max = tf.GRAPH_DEF_VERSION_MAX
for v in version, min, max:
self.assertEqual(type(v), int)
self.assertLessEqual(0, min)
self.assertLessEqual(min, version)
self.assertLessEqual(version, max)
if __name__ == "__main__":
tf.test.main()
......@@ -344,6 +344,42 @@ class FIFOQueueTest(tf.test.TestCase):
self.assertAllEqual(dequeued_t.eval(), elems)
def testEnqueueWrongShape(self):
q = tf.FIFOQueue(10, (tf.int32, tf.int32), ((), (2)))
with self.assertRaises(ValueError):
q.enqueue(([1, 2], [2, 2]))
with self.assertRaises(ValueError):
q.enqueue_many((7, [[1, 2], [3, 4], [5, 6]]))
def testBatchSizeMismatch(self):
q = tf.FIFOQueue(10, (tf.int32, tf.int32, tf.int32), ((), (), ()))
with self.assertRaises(ValueError):
q.enqueue_many(([1, 2, 3], [1, 2], [1, 2, 3]))
with self.assertRaises(ValueError):
q.enqueue_many(([1, 2, 3], [1, 2], tf.placeholder(tf.int32)))
with self.assertRaises(ValueError):
q.enqueue_many((tf.placeholder(tf.int32), [1, 2], [1, 2, 3]))
def testEnqueueManyEmptyTypeConversion(self):
q = tf.FIFOQueue(10, (tf.int32, tf.float32), ((), ()))
enq = q.enqueue_many(([], []))
self.assertEqual(tf.int32, enq.inputs[1].dtype)
self.assertEqual(tf.float32, enq.inputs[2].dtype)
def testEnqueueWrongType(self):
q = tf.FIFOQueue(10, (tf.int32, tf.float32), ((), ()))
with self.assertRaises(ValueError):
q.enqueue((tf.placeholder(tf.int32), tf.placeholder(tf.int32)))
with self.assertRaises(ValueError):
q.enqueue_many((tf.placeholder(tf.int32), tf.placeholder(tf.int32)))
def testEnqueueWrongShapeAtRuntime(self):
with self.test_session() as sess:
q = tf.FIFOQueue(10, (tf.int32, tf.int32), ((2, 2), (3, 3)))
elems_ok = np.array([1] * 4).reshape((2, 2)).astype(np.int32)
......@@ -353,8 +389,6 @@ class FIFOQueueTest(tf.test.TestCase):
tf.errors.InvalidArgumentError, r"Expected \[3,3\], got \[3,4\]"):
sess.run([enqueue_op],
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
sess.run([enqueue_op],
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
def testEnqueueDequeueManyWrongShape(self):
with self.test_session() as sess:
......
......@@ -485,5 +485,74 @@ class LSTMTest(tf.test.TestCase):
self._testDoubleInputWithDropoutAndDynamicCalculation(True)
class BidirectionalRNNTest(tf.test.TestCase):
def setUp(self):
self._seed = 23489
np.random.seed(self._seed)
def _testBidirectionalRNN(self, use_gpu):
num_units = 3
input_size = 5
batch_size = 2
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
sequence_length = tf.placeholder(tf.int64)
cell_fw = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, initializer=initializer)
cell_bw = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, initializer=initializer)
inputs = 10 * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs = tf.nn.bidirectional_rnn(
cell_fw, cell_bw, inputs, dtype=tf.float32,
sequence_length=sequence_length)
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
self.assertEqual(out.get_shape().as_list(), [batch_size, 2 * num_units])
tf.initialize_all_variables().run()
input_value = np.random.randn(batch_size, input_size)
# Run with pre-specified sequence length of 2, 3
out = sess.run(outputs, feed_dict={inputs[0]: input_value,
sequence_length: [2, 3]})
# Since the forward and backward LSTM cells were initialized with the
# same parameters, the forward and backward output has to be the same,
# but reversed in time. The format is output[time][batch][depth], and
# due to depth concatenation (as num_units=3 for both RNNs):
# - forward output: out[][][depth] for 0 <= depth < 3
# - backward output: out[][][depth] for 4 <= depth < 6
#
# First sequence in batch is length=2
# Check that the time=0 forward output is equal to time=1 backward output
self.assertEqual(out[0][0][0], out[1][0][3])
self.assertEqual(out[0][0][1], out[1][0][4])
self.assertEqual(out[0][0][2], out[1][0][5])
# Check that the time=1 forward output is equal to time=0 backward output
self.assertEqual(out[1][0][0], out[0][0][3])
self.assertEqual(out[1][0][1], out[0][0][4])
self.assertEqual(out[1][0][2], out[0][0][5])
# Second sequence in batch is length=3
# Check that the time=0 forward output is equal to time=2 backward output
self.assertEqual(out[0][1][0], out[2][1][3])
self.assertEqual(out[0][1][1], out[2][1][4])
self.assertEqual(out[0][1][2], out[2][1][5])
# Check that the time=1 forward output is equal to time=1 backward output
self.assertEqual(out[1][1][0], out[1][1][3])
self.assertEqual(out[1][1][1], out[1][1][4])
self.assertEqual(out[1][1][2], out[1][1][5])
# Check that the time=2 forward output is equal to time=0 backward output
self.assertEqual(out[2][1][0], out[0][1][3])
self.assertEqual(out[2][1][1], out[0][1][4])
self.assertEqual(out[2][1][2], out[0][1][5])
def testBidirectionalRNN(self):
self._testBidirectionalRNN(use_gpu=False)
self._testBidirectionalRNN(use_gpu=True)
if __name__ == "__main__":
tf.test.main()
......@@ -157,6 +157,26 @@ class QueueBase(object):
"""The list of dtypes for each component of a queue element."""
return self._dtypes
def _check_enqueue_dtypes(self, vals):
"""Returns `vals` as a list of `Tensor`s, having checked their dtypes.
Args:
vals: A tensor or a list of tensors, corresponding to an
enqueue(_many) tuple.
Returns:
A list of `Tensor` objects.
"""
if not isinstance(vals, (list, tuple)):
vals = [vals]
tensors = []
for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
tensors.append(ops.convert_to_tensor(val, dtype=dtype,
name="component_%d" % i))
return tensors
def enqueue(self, vals, name=None):
"""Enqueues one element to this queue.
......@@ -170,16 +190,18 @@ class QueueBase(object):
Returns:
The operation that enqueues a new tuple of tensors to the queue.
"""
if name is None:
name = "%s_enqueue" % self._name
ret = gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=name)
if not isinstance(vals, (list, tuple)):
vals = [vals]
# NOTE(mrry): Not using a shape function because we need access to
# the Queue object.
for val, shape in zip(ret.inputs[1:], self._shapes):
val.get_shape().assert_is_compatible_with(shape)
with ops.op_scope(vals, name, "%s_enqueue" % self._name) as scope:
vals = self._check_enqueue_dtypes(vals)
# NOTE(mrry): Not using a shape function because we need access to
# the `QueueBase` object.
for val, shape in zip(vals, self._shapes):
val.get_shape().assert_is_compatible_with(shape)
return ret
return gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=scope)
def enqueue_many(self, vals, name=None):
"""Enqueues zero or elements to this queue.
......@@ -199,20 +221,22 @@ class QueueBase(object):
Returns:
The operation that enqueues a batch of tuples of tensors to the queue.
"""
if name is None:
name = "%s_EnqueueMany" % self._name
ret = gen_data_flow_ops._queue_enqueue_many(
self._queue_ref, vals, name=name)
# NOTE(mrry): Not using a shape function because we need access to
# the `QueueBase` object.
batch_dim = ret.inputs[1].get_shape()[0]
for val, shape in zip(ret.inputs[1:], self._shapes):
batch_dim.merge_with(val.get_shape()[0])
val.get_shape()[1:].assert_is_compatible_with(shape)
return ret
if not isinstance(vals, (list, tuple)):
vals = [vals]
with ops.op_scope(vals, name, "%s_EnqueueMany" % self._name) as scope:
vals = self._check_enqueue_dtypes(vals)
# NOTE(mrry): Not using a shape function because we need access to
# the `QueueBase` object.
batch_dim = vals[0].get_shape().with_rank_at_least(1)[0]
for val, shape in zip(vals, self._shapes):
batch_dim = batch_dim.merge_with(
val.get_shape().with_rank_at_least(1)[0])
val.get_shape()[1:].assert_is_compatible_with(shape)
return gen_data_flow_ops._queue_enqueue_many(
self._queue_ref, vals, name=scope)
def dequeue(self, name=None):
"""Dequeues one element from this queue.
......
......@@ -148,3 +148,92 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
outputs[-1] = array_ops.identity(outputs[-1])
return (outputs, states)
def _reverse_seq(input_seq, lengths):
"""Reverse a list of Tensors up to specified lengths.
Args:
input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
lengths: A tensor of dimension batch_size, containing lengths for each
sequence in the batch. If "None" is specified, simply reverses
the list.
Returns:
time-reversed sequence
"""
if lengths is None:
return list(reversed(input_seq))
# Join into (time, batch_size, depth)
s_joined = array_ops.pack(input_seq)
# Reverse along dimension 0
s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
# Split again into list
result = array_ops.unpack(s_reversed)
return result
def bidirectional_rnn(cell_fw, cell_bw, inputs,
initial_state_fw=None, initial_state_bw=None,
dtype=None, sequence_length=None, scope=None):
"""Creates a bidirectional recurrent neural network.
Similar to the unidirectional case above (rnn) but takes input and builds
independent forward and backward RNNs with the final forward and backward
outputs depth-concatenated, such that the output will have the format
[time][batch][cell_fw.output_size + cell_bw.output_size]. The initial state
for both directions is zero by default (but can be set optionally) and no
intermediate states are ever returned -- the network is fully unrolled for
the given (passed in) length(s) of the sequence(s).
Args:
cell_fw: An instance of RNNCell, to be used for forward direction.
cell_bw: An instance of RNNCell, to be used for backward direction.
inputs: A length T list of inputs, each a vector with shape [batch_size].
initial_state_fw: (optional) An initial state for the forward RNN.
This must be a tensor of appropriate type and shape
[batch_size x cell.state_size].
initial_state_bw: (optional) Same as for initial_state_fw.
dtype: (optional) The data type for the initial state. Required if either
of the initial states are not provided.
sequence_length: An int64 vector (tensor) of size [batch_size], containing
the actual lengths for each of the sequences.
scope: VariableScope for the created subgraph; defaults to "BiRNN"
Returns:
A set of output `Tensors` where:
outputs is a length T list of outputs (one for each input), which
are depth-concatenated forward and backward outputs
Raises:
TypeError: If "cell_fw" or "cell_bw" is not an instance of RNNCell.
ValueError: If inputs is None or an empty list.
ValueError: If sequence_length is not defined.
"""
if not isinstance(cell_fw, rnn_cell.RNNCell):
raise TypeError("cell_fw must be an instance of RNNCell")
if not isinstance(cell_bw, rnn_cell.RNNCell):
raise TypeError("cell_bw must be an instance of RNNCell")
if not isinstance(inputs, list):
raise TypeError("inputs must be a list")
if not sequence_length:
raise ValueError("sequence_length has to be defined")
if not inputs:
raise ValueError("inputs must not be empty")
name = scope or "BiRNN"
# Forward direction
with vs.variable_scope(name + "_FW"):
output_fw, _ = rnn(cell_fw, inputs, initial_state_fw, dtype)
# Backward direction
with vs.variable_scope(name + "_BW"):
tmp, _ = rnn(
cell_bw, _reverse_seq(inputs, sequence_length), initial_state_bw, dtype)
output_bw = _reverse_seq(tmp, sequence_length)
# Concat each of the forward/backward outputs
outputs = [array_ops.concat(1, [fw, bw])
for fw, bw in zip(output_fw, output_bw)]
return outputs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册