提交 c092f31c 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[Tensorflow] Expose API to lookup TensorSlice.

Change: 150384503
上级 433c8c89
......@@ -712,6 +712,18 @@ Status BundleReader::Lookup(StringPiece key, Tensor* val) {
}
}
Status BundleReader::LookupTensorSlices(StringPiece key,
std::vector<TensorSlice>* slices) {
slices->clear();
BundleEntryProto entry;
TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
slices->reserve(entry.slices_size());
for (const auto& slice : entry.slices()) {
slices->emplace_back(slice);
}
return Status::OK();
}
Status BundleReader::LookupSlice(StringPiece full_tensor_key,
const TensorSlice& slice_spec, Tensor* val) {
BundleEntryProto entry;
......
......@@ -207,6 +207,12 @@ class BundleReader {
// REQUIRES: status().ok()
Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;
// Looks up the slices of the tensor keyed by "key". On OK, "slices"
// is non-empty if and only if the tensor is a partitioned tensor.
// REQUIRES: status().ok()
Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
TF_MUST_USE_RESULT;
// Looks up a specific slice of a partitioned tensor.
// It is only required that the stored slices cover the requested slice,
// namely "slice_spec" is a subset of the union of the stored slices.
......
......@@ -245,15 +245,14 @@ TEST(TensorBundleTest, PartitionedVariables) {
// Adds two slices.
// First slice: column 0, all zeros.
// Second slice: column 1 to rest, all ones.
TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
{
BundleWriter writer(Env::Default(), Prefix("foo"));
TensorSlice slice = TensorSlice::ParseOrDie("-:0,1");
TF_ASSERT_OK(writer.AddSlice("foo", kFullShape,
TensorSlice::ParseOrDie("-:0,1"),
TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
Constant<float>(0., TensorShape({5, 1}))));
TF_ASSERT_OK(writer.AddSlice("foo", kFullShape,
TensorSlice::ParseOrDie("-:1,9"),
TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
Constant<float>(1., TensorShape({5, 9}))));
TF_ASSERT_OK(writer.Finish());
}
......@@ -274,6 +273,18 @@ TEST(TensorBundleTest, PartitionedVariables) {
TF_ASSERT_OK(reader.Lookup("foo", &val));
test::ExpectTensorEqual<float>(val, expected_val);
}
// Reads all slices.
{
BundleReader reader(Env::Default(), Prefix("foo"));
TF_ASSERT_OK(reader.status());
std::vector<TensorSlice> slices;
TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));
EXPECT_EQ(2, slices.size());
EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
}
// Reads a slice consisting of first two columns, "cutting" both slices.
{
BundleReader reader(Env::Default(), Prefix("foo"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册