提交 e5353c94 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Don't prune nodes that have reference inputs.

PiperOrigin-RevId: 163390862
上级 22651083
......@@ -134,6 +134,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
......
......@@ -18,6 +18,8 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
......@@ -26,8 +28,24 @@ namespace tensorflow {
namespace grappler {
GraphRewriter::GraphRewriter(const GrapplerItem& item) {
OpRegistryInterface* op_registry = OpRegistry::Global();
for (auto& node : item.graph.node()) {
nodes_[node.name()] = &node;
NodeInfo* info = new NodeInfo();
info->def = &node;
const OpRegistrationData* op_reg_data = nullptr;
Status s = op_registry->LookUp(node.op(), &op_reg_data);
// TODO(bsteiner): make this not a best-effort lookup and evaluation?
if (s.ok()) {
s = InOutTypesForNode(node, op_reg_data->op_def, &info->inputs,
&info->outputs);
if (!s.ok()) {
info->inputs.clear();
info->outputs.clear();
}
}
nodes_[node.name()].reset(info);
}
std::unordered_set<string> function_names;
......@@ -73,11 +91,16 @@ bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const {
return cross_device_receivers_.find(&node) != cross_device_receivers_.end();
}
bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const {
return ref_receivers_.find(&node) != ref_receivers_.end();
}
void GraphRewriter::RecordConnectivity(
const NodeDef& node, const std::unordered_set<string>& function_names) {
const bool is_function =
function_names.find(node.op()) != function_names.end();
bool ref_receiver = false;
for (const auto& input : node.input()) {
int position = 0;
string input_node_name = ParseNodeName(input, &position);
......@@ -85,7 +108,8 @@ void GraphRewriter::RecordConnectivity(
if (itr == nodes_.end()) {
continue;
}
const NodeDef* fanin = itr->second;
const NodeInfo* fanin_info = itr->second.get();
const NodeDef* fanin = fanin_info->def;
if (position < 0) {
// This is a control edge
control_dependency_drivers_.insert(fanin);
......@@ -97,11 +121,20 @@ void GraphRewriter::RecordConnectivity(
if (is_function) {
function_neighbors_.insert(fanin);
}
if (position < fanin_info->outputs.size() &&
IsRefType(fanin_info->outputs[position])) {
ref_receiver = true;
}
}
if (fanin->device() != node.device()) {
cross_device_receivers_.insert(&node);
}
}
if (ref_receiver) {
ref_receivers_.insert(&node);
}
}
void GraphRewriter::ForwardInputsInternal(
......@@ -125,7 +158,7 @@ void GraphRewriter::ForwardInputsInternal(
*new_node->add_input() = input;
continue;
}
const NodeDef* input_node = itr->second;
const NodeDef* input_node = itr->second->def;
if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
ForwardInputsInternal(*input_node, nodes_to_delete, new_node);
} else {
......
......@@ -55,6 +55,9 @@ class GraphRewriter {
// device.
bool IsDrivenByAnotherDevice(const NodeDef& node) const;
// Returns true if at least one data input of the node is a reference-typed
// output of the node that produces it.
bool ReceivesRefValue(const NodeDef& node) const;
private:
void RecordConnectivity(const NodeDef& node,
const std::unordered_set<string>& function_names);
......@@ -63,11 +66,21 @@ class GraphRewriter {
const std::unordered_set<const NodeDef*>& nodes_to_delete,
NodeDef* new_node);
std::unordered_map<string, const NodeDef*> nodes_;
// Per-node bookkeeping kept by the rewriter: the node definition plus the
// data types of its inputs and outputs.
struct NodeInfo {
  // Non-owning pointer to the node definition. Initialized to nullptr so that
  // a NodeInfo never carries an indeterminate pointer, however it is created.
  const NodeDef* def = nullptr;
  // These are filled in when the NodeInfo is built, but note that they
  // may be empty - if the op could not be loaded from the registry.
  DataTypeVector inputs;
  DataTypeVector outputs;
};
std::unordered_map<string, std::unique_ptr<NodeInfo>> nodes_;
std::unordered_map<string, const NodeDef*> optimized_nodes_;
std::unordered_set<const NodeDef*> control_dependency_drivers_;
std::unordered_set<const NodeDef*> function_neighbors_;
std::unordered_set<const NodeDef*> cross_device_receivers_;
std::unordered_set<const NodeDef*> ref_receivers_;
};
} // end namespace grappler
......
......@@ -74,20 +74,23 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}
// Don't remove nodes that drive control dependencies.
// Don't remove nodes that are driven by control dependencies either since
// we can't ensure (yet) that we won't increase the number of control
// dependency edges by deleting them (for example, removing a node driven by
// 10 control edges and driving 10 control edges would result in the
// creation of 100 edges).
// Don't modify nodes that are connected to functions since that can result
// in inlining failures later on.
// Don't prune nodes that are driven by another device since these could be
// used to reduce cross device communication.
// - Don't remove nodes that drive control dependencies.
// - Don't remove nodes that are driven by control dependencies either since
// we can't ensure (yet) that we won't increase the number of control
// dependency edges by deleting them (for example, removing a node driven
// by 10 control edges and driving 10 control edges would result in the
// creation of 100 edges).
// - Don't modify nodes that are connected to functions since that can
// result in inlining failures later on.
// - Don't prune nodes that are driven by another device since these could
// be used to reduce cross device communication.
// - Don't remove nodes that receive reference values, as those can be
// converting references to non-references.
if (!rewriter.DrivesControlDependency(node) &&
!rewriter.IsDrivenByControlDependency(node) &&
!rewriter.IsConnectedToFunction(node) &&
!rewriter.IsDrivenByAnotherDevice(node)) {
!rewriter.IsDrivenByAnotherDevice(node) &&
!rewriter.ReceivesRefValue(node)) {
nodes_to_delete.insert(&node);
}
}
......
......@@ -199,6 +199,46 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
EXPECT_EQ("^c", new_e.input(1));
}
// Checks that the pruner keeps an Identity node that is fed directly by a
// Variable, while the Identity nodes further down the chain are rewired to
// read from that kept node.
TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) {
  // Build the chain a -> b -> c -> d -> e, where "a" is a Variable and every
  // other node is an Identity of its predecessor.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output a = ops::Variable(scope.WithOpName("a"), {}, DT_INT64);
  Output b = ops::Identity(scope.WithOpName("b"), a);
  Output c = ops::Identity(scope.WithOpName("c"), b);
  Output d = ops::Identity(scope.WithOpName("d"), c);
  Output e = ops::Identity(scope.WithOpName("e"), d);

  // Turn the scope into a grappler item and run the pruner on it.
  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  ModelPruner pruner;
  GraphDef pruned;
  TF_EXPECT_OK(pruner.Optimize(nullptr, item, &pruned));

  // All five nodes must be present in the output, in their original order.
  ASSERT_EQ(5, pruned.node_size());
  const NodeDef& node_a = pruned.node(0);
  const NodeDef& node_b = pruned.node(1);
  const NodeDef& node_c = pruned.node(2);
  const NodeDef& node_d = pruned.node(3);
  const NodeDef& node_e = pruned.node(4);
  EXPECT_EQ("a", node_a.name());
  EXPECT_EQ("b", node_b.name());
  EXPECT_EQ("c", node_c.name());
  EXPECT_EQ("d", node_d.name());
  EXPECT_EQ("e", node_e.name());

  // "b" still reads from "a": it cannot be dropped from the chain because it
  // converts a reference input into a non-reference value.
  EXPECT_EQ("a", node_b.input(0));
  // Everything downstream of "b" is expected to read from "b" directly.
  EXPECT_EQ("b", node_c.input(0));
  EXPECT_EQ("b", node_d.input(0));
  EXPECT_EQ("b", node_e.input(0));
}
TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册