提交 1c540341 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

For C++ shape inference, add InferenceContext, which can be used to implement

an op's shape inference function. Include functions for asserting shape rank, dim value, and
for merging shapes and dims.
Change: 123934214
上级 9963a7b5
......@@ -282,6 +282,7 @@ tf_cuda_library(
"framework/resource_mgr.h",
"framework/selective_registration.h",
"framework/session_state.h",
"framework/shape_inference.h",
"framework/tensor.h",
"framework/tensor_shape.h",
"framework/tensor_slice.h",
......@@ -927,6 +928,7 @@ filegroup(
"framework/rendezvous.h",
"framework/selective_registration.h",
"framework/session_state.h",
"framework/shape_inference.h",
"framework/tensor.h",
"framework/tensor_reference.h",
"framework/tensor_shape.h",
......
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace shape_inference {
constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;
InferenceContext::InferenceContext(const std::vector<string>& input_shapes,
int num_outputs) {
for (const string& spec : input_shapes) {
if (spec == "?") {
inputs_.push_back(CreateUnknownShape());
} else {
std::vector<const Dimension*> dims;
strings::Scanner scanner(spec);
scanner.OneLiteral("[");
while (scanner.Peek() != ']') {
if (scanner.Peek() == '?') {
scanner.OneLiteral("?");
dims.push_back(CreateUnknownDim());
} else {
scanner.RestartCapture().Many(strings::Scanner::DIGIT);
StringPiece match;
int64 dim_size = 0;
CHECK(scanner.GetResult(nullptr, &match) &&
strings::safe_strto64(match, &dim_size))
<< spec;
dims.push_back(CreateDim(dim_size));
}
if (scanner.Peek() == ',') {
scanner.OneLiteral(",");
} else {
CHECK_EQ(scanner.Peek(), ']');
}
}
CHECK(scanner.OneLiteral("]").Eos().GetResult()) << spec;
inputs_.push_back(CreateShape(dims));
}
}
for (int i = 0; i < num_outputs; ++i) {
outputs_.push_back(CreateUnknownShape());
}
}
InferenceContext::~InferenceContext() {
for (auto* s : all_shapes_) delete s;
for (auto* d : all_dims_) delete d;
}
string InferenceContext::DebugString(const Shape* s) {
if (RankKnown(s)) {
std::vector<string> vals;
for (auto d : s->dims_) vals.push_back(DebugString(d));
return strings::StrCat("[", str_util::Join(vals, ","), "]");
} else {
return "?";
}
}
string InferenceContext::DebugString(const Dimension* d) {
return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
}
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
// the shape with asserted rank in <*out>. Otherwise return an error.
Status InferenceContext::WithRank(const Shape* shape, int32 rank,
const Shape** out) {
const int32 existing = Rank(shape);
if (existing == rank) {
*out = shape;
return Status::OK();
}
if (existing == kUnknownRank) {
std::vector<const Dimension*> dims;
dims.reserve(rank);
for (int i = 0; i < rank; ++i) {
all_dims_.push_back(new Dimension());
dims.push_back(all_dims_.back());
}
all_shapes_.push_back(new Shape(dims));
*out = all_shapes_.back();
return Status::OK();
}
*out = nullptr;
return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
existing);
}
Status InferenceContext::WithValue(const Dimension* dim, int64 value,
const Dimension** out) {
const int64 existing = Value(dim);
if (existing == value) {
*out = dim;
return Status::OK();
}
if (existing == kUnknownDim) {
all_dims_.push_back(new Dimension(value));
*out = all_dims_.back();
return Status::OK();
}
*out = nullptr;
return errors::InvalidArgument("Dimension must be size ", value,
" but is size ", existing);
}
Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1,
const Dimension** out) {
if (d0 == d1 || !ValueKnown(d1)) {
*out = d0;
return Status::OK();
} else if (!ValueKnown(d0)) {
*out = d1;
return Status::OK();
} else if (Value(d0) == Value(d1)) {
*out = d0;
return Status::OK();
} else {
*out = nullptr;
return errors::InvalidArgument("Dimensions must be equal size, but are ",
Value(d0), " and ", Value(d1));
}
}
Status InferenceContext::Merge(const Shape* s0, const Shape* s1,
const Shape** out) {
if (s0 == s1 || !RankKnown(s1)) {
*out = s0;
return Status::OK();
} else if (!RankKnown(s0)) {
*out = s1;
return Status::OK();
}
const int32 rank = Rank(s0);
if (rank != Rank(s1)) {
*out = nullptr;
return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
" and ", Rank(s1));
}
bool return_s0 = true;
bool return_s1 = true;
for (int i = 0; i < rank; ++i) {
auto d0 = Dim(s0, i);
auto d1 = Dim(s1, i);
if (d0 == d1) continue;
auto v0 = Value(d0);
auto v1 = Value(d1);
if (v0 == kUnknownDim) {
if (v1 != kUnknownDim) {
return_s0 = false;
}
} else if (v1 == kUnknownDim) {
return_s1 = false;
} else if (v0 != v1) {
*out = nullptr;
return errors::InvalidArgument("Dimensions must be equal size, but are ",
Value(d0), " and ", Value(d1));
}
}
if (return_s0 || return_s1) {
*out = return_s0 ? s0 : s1;
return Status::OK();
}
// Merge dims.
std::vector<const Dimension*> dims(rank, nullptr);
for (int i = 0; i < rank; ++i) {
// Invariant for merge was checked earlier, so CHECK is ok.
TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
}
*out = CreateShape(dims);
return Status::OK();
}
const Shape* InferenceContext::CreateShape(
const std::vector<const Dimension*>& dims) {
all_shapes_.push_back(new Shape(dims));
return all_shapes_.back();
}
const Shape* InferenceContext::CreateUnknownShape() {
all_shapes_.push_back(new Shape());
return all_shapes_.back();
}
const Dimension* InferenceContext::CreateDim(int64 value) {
all_dims_.push_back(new Dimension(value));
return all_dims_.back();
}
const Dimension* InferenceContext::CreateUnknownDim() {
all_dims_.push_back(new Dimension());
return all_dims_.back();
}
} // namespace shape_inference
} // namespace tensorflow
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace shape_inference {
class InferenceContext;
// Dimension values are accessed through InferenceContext.
class Dimension {
private:
Dimension();
Dimension(int64 value);
~Dimension() {}
const int64 value_;
friend class InferenceContext;
TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};
// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
private:
Shape();
Shape(std::vector<const Dimension*> dims);
~Shape() {}
const int32 rank_;
const std::vector<const Dimension*> dims_;
friend class InferenceContext;
TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};
// Note: This is experimental support for op shape inference in C++. Shape
// inference functions are not ready to be implemented yet.
//
// An InferenceContext is created by the framework and passed to a shape
// inference function. The shape inference function calls functions on the
// context, and should call set_output() to set the shape on all outputs.
//
// All Shape* and Dimension* returned by functions of InferenceContext are owned
// by the InferenceContext.
class InferenceContext {
public:
static constexpr int32 kUnknownRank = -1;
static constexpr int64 kUnknownDim = -1;
// This is a temporary constructor used for initial testing.
//
// TODO(cwhipkey): remove this temporary constructor.
//
// Each input shape describes the input shape as follows:
// * "?" : the shape's rank and dimensions are unknown
// * "[1,?,3]" : the shape's rank is known, and dimensions can be known or
// unknown (? for unknown #1 - multiple dimensions can be
// labeled with the same unknown number, and are deduplicated to
// the same Dimension*.
InferenceContext(const std::vector<string>& input_shapes, int num_outputs);
~InferenceContext();
const Shape* input(int idx) const { return inputs_[idx]; }
int num_inputs() const { return inputs_.size(); }
void set_output(int idx, const Shape* shape);
int num_outputs() const { return outputs_.size(); }
// idx can be negative for an offset from end of dimensions.
const Dimension* Dim(const Shape* s, int32 idx) { return s->dims_[idx]; }
int32 Rank(const Shape* s) { return s->rank_; }
bool RankKnown(const Shape* s) { return Rank(s) != kUnknownRank; }
int64 Value(const Dimension* d) { return d->value_; }
bool ValueKnown(const Dimension* d) { return Value(d) != kUnknownDim; }
string DebugString(const Shape* s);
string DebugString(const Dimension* d);
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
// the shape with asserted rank in <*out>. Otherwise return an error.
//
// Note that <*out> may be set to <shape>.
Status WithRank(const Shape* shape, int32 rank,
const Shape** out) TF_MUST_USE_RESULT;
// If <dim> has value <value>, or its value is unknown, returns OK and returns
// the dimension with asserted value in <*out>. Otherwise returns an error.
//
// Note that <*out> may be set to <dim>.
Status WithValue(const Dimension* dim, int64 value,
const Dimension** out) TF_MUST_USE_RESULT;
// Merges <in0> and <in1> and returns the merged shape in <*out>. If <in0> and
// <in1> are incompatible in rank, or in the value of any dimension, returns
// an error.
//
// Note that <*out> may be set to <in0> or <in1>.
Status Merge(const Shape* in0, const Shape* in1,
const Shape** out) TF_MUST_USE_RESULT;
// Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
// and <d1> have incompatible values, returns an error.
//
// Note that <*out> may be set to <d0> or <d1>.
Status Merge(const Dimension* d0, const Dimension* d1,
const Dimension** out) TF_MUST_USE_RESULT;
// Returns a new shape with the given dims. The returned value is owned by
// this context.
const Shape* CreateShape(const std::vector<const Dimension*>& dims);
const Shape* CreateUnknownShape();
// Returns a new dimension of the given size. The returned value is owned by
// this context.
const Dimension* CreateDim(int64 value);
const Dimension* CreateUnknownDim();
private:
std::vector<Shape*> all_shapes_; // values are owned.
std::vector<Dimension*> all_dims_; // values are owned.
// inputs_ and outputs_ refer to values from all_shapes_.
std::vector<const Shape*> inputs_;
std::vector<const Shape*> outputs_;
TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64 value) : value_(value) {}
inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
inline Shape::Shape(const std::vector<const Dimension*> dims)
: rank_(dims.size()), dims_(dims) {}
} // namespace shape_inference
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace shape_inference {
TEST(ShapeInferenceTest, RankAndDimInspection) {
InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
auto in0 = c.input(0);
EXPECT_EQ("?", c.DebugString(in0));
EXPECT_FALSE(c.RankKnown(in0));
EXPECT_EQ(InferenceContext::kUnknownRank, c.Rank(in0));
auto in1 = c.input(1);
EXPECT_EQ("[1,?,3]", c.DebugString(in1));
EXPECT_TRUE(c.RankKnown(in1));
EXPECT_EQ(3, c.Rank(in1));
auto d = c.Dim(in1, 0);
EXPECT_EQ(1, c.Value(d));
EXPECT_TRUE(c.ValueKnown(d));
EXPECT_EQ("1", c.DebugString(d));
d = c.Dim(in1, 1);
EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(d));
EXPECT_FALSE(c.ValueKnown(d));
EXPECT_EQ("?", c.DebugString(d));
d = c.Dim(in1, 2);
EXPECT_EQ(3, c.Value(d));
EXPECT_TRUE(c.ValueKnown(d));
EXPECT_EQ("3", c.DebugString(d));
auto in2 = c.input(2);
EXPECT_EQ("[]", c.DebugString(in2));
EXPECT_TRUE(c.RankKnown(in2));
EXPECT_EQ(0, c.Rank(in2));
}
TEST(ShapeInferenceTest, WithRank) {
InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
const Shape* s1 = nullptr;
const Shape* s2 = nullptr;
// WithRank on a shape with unknown dimensionality always succeeds.
EXPECT_TRUE(c.WithRank(in0, 1, &s1).ok());
EXPECT_EQ("[?]", c.DebugString(s1));
EXPECT_TRUE(c.WithRank(in0, 2, &s2).ok());
EXPECT_EQ("[?,?]", c.DebugString(s2));
EXPECT_TRUE(s1 != s2); // different pointers
EXPECT_TRUE(c.Dim(s2, 0) != c.Dim(s2, 1)); // different pointers.
EXPECT_TRUE(c.WithRank(in0, 1, &s2).ok());
EXPECT_EQ("[?]", c.DebugString(s2));
EXPECT_TRUE(s1 != s2); // different pointers
EXPECT_TRUE(c.WithRank(in0, 0, &s1).ok());
EXPECT_EQ("[]", c.DebugString(s1));
// WithRank on shape with known dimensionality.
s1 = in1;
EXPECT_EQ("Invalid argument: Shape must be rank 2 but is rank 3",
c.WithRank(in1, 2, &s1).ToString());
EXPECT_TRUE(s1 == nullptr);
EXPECT_TRUE(c.WithRank(in1, 3, &s1).ok());
EXPECT_TRUE(s1 == in1); // same pointers
// Inputs are unchanged.
EXPECT_EQ("?", c.DebugString(in0));
EXPECT_EQ("[1,?,3]", c.DebugString(in1));
}
TEST(ShapeInferenceTest, WithValue) {
InferenceContext c({"[1,?]"}, 2 /* num_outputs */);
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
const Dimension* out1 = nullptr;
const Dimension* out2 = nullptr;
// WithRank on a dimension with unknown value always succeeds.
EXPECT_TRUE(c.WithValue(d1, 1, &out1).ok());
EXPECT_EQ(1, c.Value(out1));
EXPECT_TRUE(c.WithValue(d1, 2, &out2).ok());
EXPECT_EQ(2, c.Value(out2));
EXPECT_TRUE(out1 != out2); // different pointers
EXPECT_TRUE(out1 != d1); // different pointers
EXPECT_TRUE(c.WithValue(d1, 1, &out2).ok());
EXPECT_EQ(1, c.Value(out2));
EXPECT_TRUE(out1 != out2); // different pointers
// WithRank on dimension with known size.
out1 = d0;
EXPECT_EQ("Invalid argument: Dimension must be size 0 but is size 1",
c.WithValue(d0, 0, &out1).ToString());
EXPECT_TRUE(out1 == nullptr);
out1 = d0;
EXPECT_EQ("Invalid argument: Dimension must be size 2 but is size 1",
c.WithValue(d0, 2, &out1).ToString());
EXPECT_TRUE(out1 == nullptr);
EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok());
EXPECT_TRUE(d0 == out1); // same pointers
// Inputs are unchanged.
EXPECT_EQ("1", c.DebugString(d0));
EXPECT_EQ("?", c.DebugString(d1));
}
TEST(ShapeInferenceTest, MergeDim) {
InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */);
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
auto d2_b = c.Dim(c.input(0), 2);
auto d1 = c.Dim(c.input(0), 3);
auto d_unknown_b = c.Dim(c.input(0), 4);
const Dimension* out = nullptr;
// Merging anything with unknown returns the same pointer.
EXPECT_TRUE(c.Merge(d2, d_unknown, &out).ok());
EXPECT_TRUE(d2 == out);
EXPECT_TRUE(c.Merge(d_unknown, d2, &out).ok());
EXPECT_TRUE(d2 == out);
EXPECT_TRUE(c.Merge(d_unknown, d_unknown_b, &out).ok());
EXPECT_TRUE(d_unknown == out);
// Merging with self returns self.
EXPECT_TRUE(c.Merge(d2, d2, &out).ok());
EXPECT_TRUE(d2 == out);
EXPECT_TRUE(c.Merge(d_unknown, d_unknown, &out).ok());
EXPECT_TRUE(d_unknown == out);
// Merging equal values returns first one.
EXPECT_TRUE(c.Merge(d2, d2_b, &out).ok());
EXPECT_TRUE(d2 == out);
EXPECT_TRUE(c.Merge(d2_b, d2, &out).ok());
EXPECT_TRUE(d2_b == out);
// Merging inequal values is an error.
EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 2 and 1",
c.Merge(d2, d1, &out).ToString());
EXPECT_TRUE(out == nullptr);
EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 1 and 2",
c.Merge(d1, d2, &out).ToString());
EXPECT_TRUE(out == nullptr);
}
TEST(ShapeInferenceTest, MergeShape) {
InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
2 /* num_outputs */);
auto s_unknown = c.input(0);
auto s_1_2 = c.input(1);
auto s_u_2 = c.input(2);
auto s_1_u = c.input(3);
auto s_1_3 = c.input(4);
auto s_unknown_b = c.input(5);
auto s_1 = c.input(6);
const Shape* out = nullptr;
// Merging any shape with unknown returns the shape.
EXPECT_TRUE(c.Merge(s_unknown, s_1_2, &out).ok());
EXPECT_TRUE(s_1_2 == out);
EXPECT_TRUE(c.Merge(s_u_2, s_unknown, &out).ok());
EXPECT_TRUE(s_u_2 == out);
EXPECT_TRUE(c.Merge(s_unknown, s_unknown_b, &out).ok());
EXPECT_TRUE(s_unknown == out);
// Merging with self returns self.
EXPECT_TRUE(c.Merge(s_1_2, s_1_2, &out).ok());
EXPECT_TRUE(out == s_1_2);
// Merging where one of the inputs is the right answer - return that input.
out = nullptr;
EXPECT_TRUE(c.Merge(s_1_2, s_u_2, &out).ok());
EXPECT_TRUE(s_1_2 == out);
out = nullptr;
EXPECT_TRUE(c.Merge(s_u_2, s_1_2, &out).ok());
EXPECT_TRUE(s_1_2 == out);
// Merging where neither input is the right answer.
EXPECT_TRUE(c.Merge(s_u_2, s_1_u, &out).ok());
EXPECT_TRUE(out != s_u_2);
EXPECT_TRUE(out != s_1_u);
EXPECT_EQ("[1,2]", c.DebugString(out));
EXPECT_TRUE(c.Dim(s_1_u, 0) == c.Dim(out, 0)); // same pointers
EXPECT_TRUE(c.Dim(s_u_2, 1) == c.Dim(out, 1)); // same pointers
// Incompatible merges give errors and set out to nullptr.
out = s_unknown;
EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 2 and 3",
c.Merge(s_u_2, s_1_3, &out).ToString());
EXPECT_TRUE(out == nullptr);
out = s_unknown;
EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 3 and 2",
c.Merge(s_1_3, s_u_2, &out).ToString());
EXPECT_TRUE(out == nullptr);
out = s_unknown;
EXPECT_EQ("Invalid argument: Shapes must be equal rank, but are 1 and 2",
c.Merge(s_1, s_1_2, &out).ToString());
EXPECT_TRUE(out == nullptr);
}
TEST(ShapeInferenceTest, CreateShape) {
InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */);
std::vector<const Dimension*> dims;
auto in0 = c.input(0);
const int rank = c.Rank(in0);
for (int i = 0; i < rank; ++i) {
dims.push_back(c.Dim(in0, rank - i - 1));
}
auto s = c.CreateShape(dims);
EXPECT_EQ("[5,?,3,2,1]", c.DebugString(s));
EXPECT_TRUE(c.Dim(s, 0) == c.Dim(in0, rank - 1));
auto s2 = c.CreateShape(dims);
EXPECT_TRUE(s != s2); // different pointers
EXPECT_TRUE(c.Dim(s2, 0) == c.Dim(in0, rank - 1));
}
TEST(ShapeInferenceTest, CreateUnknownShape) {
InferenceContext c({}, 2 /* num_outputs */);
auto u0 = c.CreateUnknownShape();
auto u1 = c.CreateUnknownShape();
EXPECT_EQ("?", c.DebugString(u0));
EXPECT_EQ("?", c.DebugString(u1));
EXPECT_TRUE(u0 != u1); // different pointers
}
TEST(ShapeInferenceTest, CreateDim) {
InferenceContext c({}, 2 /* num_outputs */);
auto* d0 = c.CreateDim(1);
auto* d1 = c.CreateDim(1);
auto* d2 = c.CreateDim(2);
EXPECT_EQ("1", c.DebugString(d0));
EXPECT_EQ("1", c.DebugString(d1));
EXPECT_TRUE(d0 != d1); // different pointers
EXPECT_EQ("2", c.DebugString(d2));
}
TEST(ShapeInferenceTest, CreateUnknownDim) {
InferenceContext c({}, 2 /* num_outputs */);
auto* d0 = c.CreateUnknownDim();
auto* d1 = c.CreateUnknownDim();
EXPECT_EQ("?", c.DebugString(d0));
EXPECT_EQ("?", c.DebugString(d1));
EXPECT_TRUE(d0 != d1); // different pointers
}
} // namespace shape_inference
} // namespace tensorflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册