Commit 1b2a6d9d authored by Andiry Xu, committed by TensorFlower Gardener

Do not update unknown Placeholder with empty shape.

Replacing an unknown shape with an empty shape causes an incompatibility between the inferred shape and the actual (annotated) shape. Also factor out UpdatePlaceholderShape for readability.

PiperOrigin-RevId: 251288300
Parent 026bcc0c
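The behavioral change is the guard around the final shape overwrite: the Placeholder's shape attribute is rewritten only when the inferred shape has known dimensions, and a fully unknown shape is left intact instead of being replaced by an empty (scalar) shape, keeping the graph compatible with shape annotations (b/134092018). Below is a minimal standalone sketch of that guard, not part of the diff; FakeShapeProto and MaybeUpdateShape are simplified, hypothetical stand-ins for TensorShapeProto and UpdatePlaceholderShape.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for TensorShapeProto (not a TensorFlow API):
// unknown_rank distinguishes "shape unknown" from "scalar" (zero dims).
struct FakeShapeProto {
  bool unknown_rank = false;
  std::vector<int64_t> dims;
};

// Mirrors the guard this commit adds: only overwrite the Placeholder's
// shape attribute when the inferred shape has known dimensions; an
// unknown shape is kept intact rather than replaced by an empty shape.
void MaybeUpdateShape(const FakeShapeProto& inferred, FakeShapeProto* attr_shape) {
  if (!inferred.dims.empty()) {
    *attr_shape = inferred;
  }
  // Otherwise: leave attr_shape untouched. The pre-fix code wrote an
  // empty shape here, silently turning "unknown rank" into "scalar".
}

int main() {
  FakeShapeProto attr;      // Placeholder annotated with unknown rank.
  attr.unknown_rank = true;
  FakeShapeProto inferred;  // Nothing could be inferred: still empty.
  MaybeUpdateShape(inferred, &attr);
  std::cout << std::boolalpha << attr.unknown_rank << "\n";  // true: annotation preserved
  return 0;
}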
@@ -107,6 +107,99 @@ Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
}
// Replace unknown dimensions in Placeholder shape if
// cfg.placeholder_unknown_output_shape_dim is set or
// the Placeholder node has _output_shapes.
// Otherwise keep it intact to remain compatible with shape annotation
// (b/134092018).
Status UpdatePlaceholderShape(
const ItemConfig& cfg,
const std::unordered_set<string>& signature_feed_nodes,
GrapplerItem* new_item, NodeDef* node) {
if (node->attr().count("dtype") == 0) {
return errors::Internal("Unknown type for placeholder ", node->name(),
", skipping this input");
}
DataType type = node->attr().at("dtype").type();
// TODO(andiryxu): Consider cfg.placeholder_unknown_output_shape_dim >= 0 and
// _output_shapes is present case.
if (node->attr().count("shape") == 0) {
return errors::Internal("Unknown shape for placeholder ", node->name(),
", skipping this input");
}
// Replace all unknown dimensions in the placeholder's tensorshape proto
// with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
// from it. We do this because in newer protos, the input placeholder
// shape is not empty if the shape is partially defined.
TensorShape shape;
TensorShapeProto shape_proto;
Status make_shape_status = ReplaceUnknownShapeDim(
cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
if (!make_shape_status.ok()) {
return errors::Internal("Invalid shape for placeholder ", node->name(),
": ", make_shape_status, ", skipping this input");
}
// Some placeholder nodes have a mis-match between the node
// attribute "shape" and a different node attribute "_output_shapes".
// Specifically, a shape with shape.dims() == 0 could indicate either
// a scalar or an unknown shape. In those cases, we check _output_shapes
// for additional information.
// This case is observed in the bnmt graphs. Have not observed any
// cases where there was more than 1 _output_shapes, so limit it
// to cases where there is only 1 _output_shapes.
// We only do this if cfg.placeholder_unknown_output_shape_dim has
// been set to avoid crashing non-BNMT graphs.
// TODO(andiryxu): Investigate if this is a bug in BNMT graph.
if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
(node->attr().count("_output_shapes") == 1)) {
const auto& output_shapes =
node->attr().at("_output_shapes").list().shape(0);
if (output_shapes.dim_size() != 0) {
shape.Clear();
shape_proto.clear_dim();
for (const auto& dim : output_shapes.dim()) {
auto size = dim.size();
if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
shape.AddDim(size);
shape_proto.add_dim()->set_size(size);
}
}
}
Tensor fake_input(type, shape);
InitializeTensor(type, &fake_input);
if (cfg.feed_nodes.empty()) {
// No specific feed nodes were given. Assume all placeholders are fed.
if (signature_feed_nodes.count(node->name()) == 0) {
new_item->feed.emplace_back(node->name(), fake_input);
}
} else if (cfg.feed_nodes.count(node->name()) > 0) {
// If specific feed nodes were given, only update their tensors.
auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
[&node](std::pair<string, Tensor>& f) {
return f.first == node->name();
});
DCHECK(it != new_item->feed.end());
it->second = fake_input;
}
// Set the shape of the node in the graph. This is needed for statically
// inferring shapes and is a no-op when dynamically inferring shapes as
// the Placeholder shape will match the shape passed from new_item->feed.
// Only replace node shape with known shape. For unknown shape keep it intact
// (b/134092018).
if (!shape_proto.dim().empty())
*(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;
return Status::OK();
}
} // namespace
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
@@ -439,83 +532,9 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
for (auto& node : *new_item->graph.mutable_node()) {
if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
if (node.attr().count("dtype") == 0) {
LOG(ERROR) << "Unknown type for placeholder " << node.name()
<< ", skipping this input";
return nullptr;
}
DataType type = node.attr().at("dtype").type();
if (node.attr().count("shape") == 0) {
LOG(INFO) << "Unknown shape for placeholder " << node.name()
<< ", skipping this input";
return nullptr;
}
// Replace all unknown dimensions in the placeholder's tensorshape proto
// with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
// from it. We do this because in newer protos, the input placeholder
// shape is not empty if the shape is partially defined.
TensorShape shape;
TensorShapeProto shape_proto;
Status make_shape_status = ReplaceUnknownShapeDim(
cfg, node.attr().at("shape").shape(), &shape_proto, &shape);
if (!make_shape_status.ok()) {
LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": "
<< make_shape_status << ", skipping this input";
return nullptr;
}
// Some placeholder nodes have a mis-match between the node
// attribute "shape" and a different node attribute "_output_shapes".
// Specifically, a shape with shape.dims() == 0 could indicate either
// a scalar or an unknown shape. In those cases, we check _output_shapes
// for additional information.
// This case is observed in the bnmt graphs. Have not observed any
// cases where there was more than 1 _output_shapes, so limit it
// to cases where there is only 1 _output_shapes.
// We only do this if cfg.placeholder_unknown_output_shape_dim has
// been set to avoid crashing non-BNMT graphs.
if ((cfg.placeholder_unknown_output_shape_dim >= 0) &&
(shape.dims() == 0) && (node.attr().count("_output_shapes") == 1)) {
const auto& output_shapes =
node.attr().at("_output_shapes").list().shape(0);
if (output_shapes.dim_size() != 0) {
shape.Clear();
shape_proto.clear_dim();
for (const auto& dim : output_shapes.dim()) {
auto size = dim.size();
if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
shape.AddDim(size);
shape_proto.add_dim()->set_size(size);
}
}
}
Tensor fake_input(type, shape);
InitializeTensor(type, &fake_input);
if (cfg.feed_nodes.empty()) {
// No specific feed nodes were given. Assume all placeholders are fed.
if (signature_feed_nodes.count(node.name()) == 0) {
new_item->feed.emplace_back(node.name(), fake_input);
}
} else if (cfg.feed_nodes.count(node.name()) > 0) {
// If specific feed nodes were given, only update their tensors.
auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
[&node](std::pair<string, Tensor>& f) {
return f.first == node.name();
});
QCHECK(it != new_item->feed.end());
it->second = fake_input;
}
// Set the shape of the node in the graph. This is needed for statically
// inferring shapes and is a no-op when dynamically inferring shapes as
// the Placeholder shape will match the shape passed from new_item->feed.
*(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto;
Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
new_item.get(), &node);
if (!s.ok()) return nullptr;
} else if (IsConstant(node)) {
auto it = asset_node_to_value.find(node.name());
if (it != asset_node_to_value.end()) {
......
@@ -336,6 +336,149 @@ TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) {
EXPECT_EQ(item->fetch[0], "z");
}
TEST_F(GrapplerItemBuilderTest, UnknownRankPlaceholderTest) {
MetaGraphDef meta_graph;
const char* text_proto = R"EOF(
graph_def {
node {
name: "x"
op: "Placeholder"
attr { key: "dtype" value { type: DT_FLOAT } }
attr { key: "shape" value { shape { unknown_rank: true } } }
}
versions {
producer: 51
}
}
collection_def {
key: "train_op"
value {
node_list {
value: "x:0"
}
}
}
)EOF";
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
ItemConfig cfg;
std::unique_ptr<GrapplerItem> item =
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
ASSERT_TRUE(item != nullptr);
const NodeDef& node = item->graph.node(0);
const auto iter = node.attr().find("shape");
ASSERT_TRUE(iter != node.attr().end());
ASSERT_TRUE(iter->second.has_shape());
const auto& shape = iter->second.shape();
// Do not update unknown shape.
EXPECT_TRUE(shape.unknown_rank());
}
TEST_F(GrapplerItemBuilderTest, ConfigPlaceholderTest) {
MetaGraphDef meta_graph;
const char* text_proto = R"EOF(
graph_def {
node {
name: "x"
op: "Placeholder"
attr { key: "dtype" value { type: DT_FLOAT } }
attr { key: "shape" value {
shape {
dim {
size: -1
}
dim {
size: -1
}
}
} }
}
versions {
producer: 51
}
}
collection_def {
key: "train_op"
value {
node_list {
value: "x:0"
}
}
}
)EOF";
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
ItemConfig cfg;
cfg.placeholder_unknown_output_shape_dim = 64;
std::unique_ptr<GrapplerItem> item =
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
ASSERT_TRUE(item != nullptr);
const NodeDef& node = item->graph.node(0);
const auto iter = node.attr().find("shape");
ASSERT_TRUE(iter != node.attr().end());
ASSERT_TRUE(iter->second.has_shape());
const auto& shape = iter->second.shape();
EXPECT_EQ(shape.dim_size(), 2);
// Shape updated with placeholder_unknown_output_shape_dim.
EXPECT_EQ(shape.dim(0).size(), 64);
EXPECT_EQ(shape.dim(1).size(), 64);
}
TEST_F(GrapplerItemBuilderTest, OutputShapePlaceholderTest) {
MetaGraphDef meta_graph;
const char* text_proto = R"EOF(
graph_def {
node {
name: "x"
op: "Placeholder"
attr { key: "dtype" value { type: DT_FLOAT } }
attr { key: "shape" value { shape { unknown_rank: true } } }
attr { key: "_output_shapes" value { list {
shape {
dim {
size: -1
}
dim {
size: 32
}
}
} } }
}
versions {
producer: 51
}
}
collection_def {
key: "train_op"
value {
node_list {
value: "x:0"
}
}
}
)EOF";
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
ItemConfig cfg;
cfg.placeholder_unknown_output_shape_dim = 64;
std::unique_ptr<GrapplerItem> item =
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
ASSERT_TRUE(item != nullptr);
const NodeDef& node = item->graph.node(0);
const auto iter = node.attr().find("shape");
ASSERT_TRUE(iter != node.attr().end());
ASSERT_TRUE(iter->second.has_shape());
const auto& shape = iter->second.shape();
EXPECT_EQ(shape.dim_size(), 2);
// Shape updated with placeholder_unknown_output_shape_dim
// and _output_shapes attr.
EXPECT_EQ(shape.dim(0).size(), 64);
EXPECT_EQ(shape.dim(1).size(), 32);
}
} // namespace
} // namespace grappler
} // namespace tensorflow