From c51eb496bfed514c5b64438d0565fc2a67cdd1da Mon Sep 17 00:00:00 2001 From: Tong Shen Date: Tue, 16 Jul 2019 16:13:53 -0700 Subject: [PATCH] Pass attributes when lowering functional If/While. PiperOrigin-RevId: 258460978 --- tensorflow/core/common_runtime/lower_if_op.cc | 27 ++++++++++------ .../core/common_runtime/lower_while_op.cc | 31 +++++++++++-------- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index ec37d72faab..2cd89eab756 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -38,11 +38,12 @@ class CondBuilder { enum Branch { kElseBranch = 0, kThenBranch = 1 }; // Create a CondBuilder to create the lowered form of `if_op` with then and - // else functions named `then_fn_name` and `else_fn_name` respectively in the - // `graph`. The functions should be available in `flib`. - CondBuilder(Node* if_op, const string& then_fn_name, - const string& else_fn_name, const FunctionLibraryDefinition& flib, - bool keep_node_fetchable, Graph* graph); + // else functions `then_fn` and `else_fn` respectively in the `graph`. The + // functions should be available in `flib`. + CondBuilder(Node* if_op, const NameAttrList& then_fn, + const NameAttrList& else_fn, + const FunctionLibraryDefinition& flib, bool keep_node_fetchable, + Graph* graph); // Constructs the basic conditional control flow using switch and merge nodes. Status CreatePivotNodes(); @@ -103,8 +104,8 @@ class CondBuilder { NodeBuilder else_call_builder_; }; -CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name, - const string& else_fn_name, +CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn, + const NameAttrList& else_fn, const FunctionLibraryDefinition& flib, bool keep_node_fetchable, Graph* graph) : if_op_(if_op), @@ -113,15 +114,21 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name, name_(if_op->name()), keep_node_fetchable_(keep_node_fetchable), debug_info_(*if_op_), - then_call_builder_(NewName("then"), then_fn_name, graph->op_registry(), + then_call_builder_(NewName("then"), then_fn.name(), graph->op_registry(), &debug_info_), - else_call_builder_(NewName("else"), else_fn_name, graph->op_registry(), + else_call_builder_(NewName("else"), else_fn.name(), graph->op_registry(), &debug_info_) { TF_CHECK_OK(if_op_->input_tensor(0, &pred_)); then_call_builder_.Device(if_op_->requested_device()); then_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true); + for (const auto& i : then_fn.attr()) { + then_call_builder_.Attr(i.first, i.second); + } else_call_builder_.Device(if_op_->requested_device()); else_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true); + for (const auto& i : else_fn.attr()) { + else_call_builder_.Attr(i.first, i.second); + } } Status CondBuilder::CreatePivotNodes() { @@ -279,7 +286,7 @@ Status RewriteIfNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib, return errors::InvalidArgument("Else branch function missing"); } - CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib, + CondBuilder cb(n, then_attr->func(), else_attr->func(), flib, keep_node_fetchable, g); TF_RETURN_IF_ERROR(cb.CreatePivotNodes()); TF_RETURN_IF_ERROR(cb.AddInputs()); diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc index 65f4caf7573..c1c5e510bd0 100644 --- a/tensorflow/core/common_runtime/lower_while_op.cc +++ b/tensorflow/core/common_runtime/lower_while_op.cc @@ -56,13 +56,12 @@ constexpr const char* const kLowerAsMultiDeviceFunctionAttr = // consumer class LowerWhileHelper { public: - static Status Run(Node* while_op, const string& cond_fn_name, - const string& body_fn_name, int parallel_iterations, + static Status Run(Node* while_op, const NameAttrList& cond_fn, + const NameAttrList& body_fn, int parallel_iterations, Graph* graph, const FunctionLibraryDefinition& flib, bool keep_node_fetchable) { - LowerWhileHelper helper(while_op, cond_fn_name, body_fn_name, - parallel_iterations, graph, flib, - keep_node_fetchable); + LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations, + graph, flib, keep_node_fetchable); return helper.RunInternal(); } @@ -70,8 +69,8 @@ class LowerWhileHelper { // Create a LowerWhileHelper to create the lowering of While op that has cond // and body functions named `cond_fn_name` and `body_fn_name` respectively in // the given graph. - LowerWhileHelper(Node* while_op, const string& cond_fn_name, - const string& body_fn_name, int parallel_iterations, + LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn, + const NameAttrList& body_fn, int parallel_iterations, Graph* graph, const FunctionLibraryDefinition& flib, bool keep_node_fetchable); @@ -157,8 +156,8 @@ class LowerWhileHelper { size_t num_loop_inputs_; }; -LowerWhileHelper::LowerWhileHelper(Node* while_op, const string& cond_fn_name, - const string& body_fn_name, +LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn, + const NameAttrList& body_fn, int parallel_iterations, Graph* graph, const FunctionLibraryDefinition& flib, bool keep_node_fetchable) @@ -169,13 +168,19 @@ LowerWhileHelper::LowerWhileHelper(Node* while_op, const string& cond_fn_name, parallel_iterations_(parallel_iterations), keep_node_fetchable_(keep_node_fetchable), debug_info_(*while_op_), - cond_call_builder_(NewName("cond"), cond_fn_name, graph->op_registry(), + cond_call_builder_(NewName("cond"), cond_fn.name(), graph->op_registry(), &debug_info_), - body_call_builder_(NewName("body"), body_fn_name, graph->op_registry(), + body_call_builder_(NewName("body"), body_fn.name(), graph->op_registry(), &debug_info_), num_loop_inputs_(while_op_->num_inputs()) { cond_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true); + for (const auto& i : cond_fn.attr()) { + cond_call_builder_.Attr(i.first, i.second); + } body_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true); + for (const auto& i : body_fn.attr()) { + body_call_builder_.Attr(i.first, i.second); + } // We intentionally `resize` instead of `reserve` space in `enter_nodes_` // because we need to set it's elements out of order in `CreateEnterNodes`. enter_nodes_.resize(num_loop_inputs_); @@ -432,8 +437,8 @@ Status RewriteWhileNode(Node* n, Graph* g, } TF_RETURN_IF_ERROR(LowerWhileHelper::Run( - n, cond_attr->func().name(), body_attr->func().name(), - parallel_iterations_attr->i(), g, flib, keep_node_fetchable)); + n, cond_attr->func(), body_attr->func(), parallel_iterations_attr->i(), g, + flib, keep_node_fetchable)); g->RemoveNode(n); return Status::OK(); -- GitLab