提交 e5353c94 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Don't prune nodes that have reference inputs.

PiperOrigin-RevId: 163390862
上级 22651083
...@@ -134,6 +134,7 @@ cc_library( ...@@ -134,6 +134,7 @@ cc_library(
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
......
...@@ -18,6 +18,8 @@ limitations under the License. ...@@ -18,6 +18,8 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils.h"
...@@ -26,8 +28,24 @@ namespace tensorflow { ...@@ -26,8 +28,24 @@ namespace tensorflow {
namespace grappler { namespace grappler {
GraphRewriter::GraphRewriter(const GrapplerItem& item) { GraphRewriter::GraphRewriter(const GrapplerItem& item) {
OpRegistryInterface* op_registry = OpRegistry::Global();
for (auto& node : item.graph.node()) { for (auto& node : item.graph.node()) {
nodes_[node.name()] = &node; NodeInfo* info = new NodeInfo();
info->def = &node;
const OpRegistrationData* op_reg_data = nullptr;
Status s = op_registry->LookUp(node.op(), &op_reg_data);
// TODO(bsteiner): make this not a best-effort lookup and evaluation?
if (s.ok()) {
s = InOutTypesForNode(node, op_reg_data->op_def, &info->inputs,
&info->outputs);
if (!s.ok()) {
info->inputs.clear();
info->outputs.clear();
}
}
nodes_[node.name()].reset(info);
} }
std::unordered_set<string> function_names; std::unordered_set<string> function_names;
...@@ -73,11 +91,16 @@ bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const { ...@@ -73,11 +91,16 @@ bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const {
return cross_device_receivers_.find(&node) != cross_device_receivers_.end(); return cross_device_receivers_.find(&node) != cross_device_receivers_.end();
} }
bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const {
return ref_receivers_.find(&node) != ref_receivers_.end();
}
void GraphRewriter::RecordConnectivity( void GraphRewriter::RecordConnectivity(
const NodeDef& node, const std::unordered_set<string>& function_names) { const NodeDef& node, const std::unordered_set<string>& function_names) {
const bool is_function = const bool is_function =
function_names.find(node.op()) != function_names.end(); function_names.find(node.op()) != function_names.end();
bool ref_receiver = false;
for (const auto& input : node.input()) { for (const auto& input : node.input()) {
int position = 0; int position = 0;
string input_node_name = ParseNodeName(input, &position); string input_node_name = ParseNodeName(input, &position);
...@@ -85,7 +108,8 @@ void GraphRewriter::RecordConnectivity( ...@@ -85,7 +108,8 @@ void GraphRewriter::RecordConnectivity(
if (itr == nodes_.end()) { if (itr == nodes_.end()) {
continue; continue;
} }
const NodeDef* fanin = itr->second; const NodeInfo* fanin_info = itr->second.get();
const NodeDef* fanin = fanin_info->def;
if (position < 0) { if (position < 0) {
// This is a control edge // This is a control edge
control_dependency_drivers_.insert(fanin); control_dependency_drivers_.insert(fanin);
...@@ -97,11 +121,20 @@ void GraphRewriter::RecordConnectivity( ...@@ -97,11 +121,20 @@ void GraphRewriter::RecordConnectivity(
if (is_function) { if (is_function) {
function_neighbors_.insert(fanin); function_neighbors_.insert(fanin);
} }
if (position < fanin_info->outputs.size() &&
IsRefType(fanin_info->outputs[position])) {
ref_receiver = true;
}
} }
if (fanin->device() != node.device()) { if (fanin->device() != node.device()) {
cross_device_receivers_.insert(&node); cross_device_receivers_.insert(&node);
} }
} }
if (ref_receiver) {
ref_receivers_.insert(&node);
}
} }
void GraphRewriter::ForwardInputsInternal( void GraphRewriter::ForwardInputsInternal(
...@@ -125,7 +158,7 @@ void GraphRewriter::ForwardInputsInternal( ...@@ -125,7 +158,7 @@ void GraphRewriter::ForwardInputsInternal(
*new_node->add_input() = input; *new_node->add_input() = input;
continue; continue;
} }
const NodeDef* input_node = itr->second; const NodeDef* input_node = itr->second->def;
if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) { if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
ForwardInputsInternal(*input_node, nodes_to_delete, new_node); ForwardInputsInternal(*input_node, nodes_to_delete, new_node);
} else { } else {
......
...@@ -55,6 +55,9 @@ class GraphRewriter { ...@@ -55,6 +55,9 @@ class GraphRewriter {
// device. // device.
bool IsDrivenByAnotherDevice(const NodeDef& node) const; bool IsDrivenByAnotherDevice(const NodeDef& node) const;
// Returns true if the node has input from a stateful op.
bool ReceivesRefValue(const NodeDef& node) const;
private: private:
void RecordConnectivity(const NodeDef& node, void RecordConnectivity(const NodeDef& node,
const std::unordered_set<string>& function_names); const std::unordered_set<string>& function_names);
...@@ -63,11 +66,21 @@ class GraphRewriter { ...@@ -63,11 +66,21 @@ class GraphRewriter {
const std::unordered_set<const NodeDef*>& nodes_to_delete, const std::unordered_set<const NodeDef*>& nodes_to_delete,
NodeDef* new_node); NodeDef* new_node);
std::unordered_map<string, const NodeDef*> nodes_; struct NodeInfo {
const NodeDef* def;
// These are filled in when the NodeInfo is built, but not that they
// may be empty - if the op could not be loaded from the registry.
DataTypeVector inputs;
DataTypeVector outputs;
};
std::unordered_map<string, std::unique_ptr<NodeInfo>> nodes_;
std::unordered_map<string, const NodeDef*> optimized_nodes_; std::unordered_map<string, const NodeDef*> optimized_nodes_;
std::unordered_set<const NodeDef*> control_dependency_drivers_; std::unordered_set<const NodeDef*> control_dependency_drivers_;
std::unordered_set<const NodeDef*> function_neighbors_; std::unordered_set<const NodeDef*> function_neighbors_;
std::unordered_set<const NodeDef*> cross_device_receivers_; std::unordered_set<const NodeDef*> cross_device_receivers_;
std::unordered_set<const NodeDef*> ref_receivers_;
}; };
} // end namespace grappler } // end namespace grappler
......
...@@ -74,20 +74,23 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, ...@@ -74,20 +74,23 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
continue; continue;
} }
// Don't remove nodes that drive control dependencies. // - Don't remove nodes that drive control dependencies.
// Don't remove nodes that are driven by control dependencies either since // - Don't remove nodes that are driven by control dependencies either since
// we can't ensure (yet) that we won't increase the number of control // we can't ensure (yet) that we won't increase the number of control
// dependency edges by deleting them (for example, removing a node driven by // dependency edges by deleting them (for example, removing a node driven
// 10 control edges and driving 10 control edges would result in the // by 10 control edges and driving 10 control edges would result in the
// creation of 100 edges). // creation of 100 edges).
// Don't modify nodes that are connected to functions since that can result // - Don't modify nodes that are connected to functions since that can
// in inlining failures later on. // result in inlining failures later on.
// Don't prune nodes that are driven by another device since these could be // - Don't prune nodes that are driven by another device since these could
// used to reduce cross device communication. // be used to reduce cross device communication.
// - Don't remove nodes that receive reference values, as those can be
// converting references to non-references.
if (!rewriter.DrivesControlDependency(node) && if (!rewriter.DrivesControlDependency(node) &&
!rewriter.IsDrivenByControlDependency(node) && !rewriter.IsDrivenByControlDependency(node) &&
!rewriter.IsConnectedToFunction(node) && !rewriter.IsConnectedToFunction(node) &&
!rewriter.IsDrivenByAnotherDevice(node)) { !rewriter.IsDrivenByAnotherDevice(node) &&
!rewriter.ReceivesRefValue(node)) {
nodes_to_delete.insert(&node); nodes_to_delete.insert(&node);
} }
} }
......
...@@ -199,6 +199,46 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { ...@@ -199,6 +199,46 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
EXPECT_EQ("^c", new_e.input(1)); EXPECT_EQ("^c", new_e.input(1));
} }
TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
// Make graph of Identity(Identity(Identity(Identity(Variable)))).
Output a = ops::Variable(s.WithOpName("a"), {}, DT_INT64);
Output b = ops::Identity(s.WithOpName("b"), a);
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::Identity(s.WithOpName("e"), d);
// Run pruner.
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ModelPruner pruner;
GraphDef output;
Status status = pruner.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
// Get the updated nodes.
ASSERT_EQ(5, output.node_size());
const NodeDef& new_a = output.node(0);
const NodeDef& new_b = output.node(1);
const NodeDef& new_c = output.node(2);
const NodeDef& new_d = output.node(3);
const NodeDef& new_e = output.node(4);
EXPECT_EQ("a", new_a.name());
EXPECT_EQ("b", new_b.name());
EXPECT_EQ("c", new_c.name());
EXPECT_EQ("d", new_d.name());
EXPECT_EQ("e", new_e.name());
// Verify the connections. Identity "b" can't be removed from the chain
// because it is converting a reference input to a non-reference, so c,d,e all
// refer to it as an input.
EXPECT_EQ("a", new_b.input(0));
EXPECT_EQ("b", new_c.input(0));
EXPECT_EQ("b", new_d.input(0));
EXPECT_EQ("b", new_e.input(0));
}
TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) { TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) {
// Build a simple graph with a few trivially prunable ops. // Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Scope s = tensorflow::Scope::NewRootScope();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册