提交 c51eb496 编写于 作者: T Tong Shen 提交者: TensorFlower Gardener

Pass attributes when lowering functional If/While.

PiperOrigin-RevId: 258460978
上级 ca7acecc
......@@ -38,11 +38,12 @@ class CondBuilder {
enum Branch { kElseBranch = 0, kThenBranch = 1 };
// Create a CondBuilder to create the lowered form of `if_op` with then and
// else functions named `then_fn_name` and `else_fn_name` respectively in the
// `graph`. The functions should be available in `flib`.
CondBuilder(Node* if_op, const string& then_fn_name,
const string& else_fn_name, const FunctionLibraryDefinition& flib,
bool keep_node_fetchable, Graph* graph);
// else functions `then_fn` and `else_fn` respectively in the `graph`. The
// functions should be available in `flib`.
CondBuilder(Node* if_op, const NameAttrList& then_fn,
const NameAttrList& else_fn,
const FunctionLibraryDefinition& flib, bool keep_node_fetchable,
Graph* graph);
// Constructs the basic conditional control flow using switch and merge nodes.
Status CreatePivotNodes();
......@@ -103,8 +104,8 @@ class CondBuilder {
NodeBuilder else_call_builder_;
};
CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
const string& else_fn_name,
CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn,
const NameAttrList& else_fn,
const FunctionLibraryDefinition& flib,
bool keep_node_fetchable, Graph* graph)
: if_op_(if_op),
......@@ -113,15 +114,21 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
name_(if_op->name()),
keep_node_fetchable_(keep_node_fetchable),
debug_info_(*if_op_),
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry(),
then_call_builder_(NewName("then"), then_fn.name(), graph->op_registry(),
&debug_info_),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry(),
else_call_builder_(NewName("else"), else_fn.name(), graph->op_registry(),
&debug_info_) {
TF_CHECK_OK(if_op_->input_tensor(0, &pred_));
then_call_builder_.Device(if_op_->requested_device());
then_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
for (const auto& i : then_fn.attr()) {
then_call_builder_.Attr(i.first, i.second);
}
else_call_builder_.Device(if_op_->requested_device());
else_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
for (const auto& i : else_fn.attr()) {
else_call_builder_.Attr(i.first, i.second);
}
}
Status CondBuilder::CreatePivotNodes() {
......@@ -279,7 +286,7 @@ Status RewriteIfNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
return errors::InvalidArgument("Else branch function missing");
}
CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib,
CondBuilder cb(n, then_attr->func(), else_attr->func(), flib,
keep_node_fetchable, g);
TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
TF_RETURN_IF_ERROR(cb.AddInputs());
......
......@@ -56,13 +56,12 @@ constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
// consumer
class LowerWhileHelper {
public:
static Status Run(Node* while_op, const string& cond_fn_name,
const string& body_fn_name, int parallel_iterations,
static Status Run(Node* while_op, const NameAttrList& cond_fn,
const NameAttrList& body_fn, int parallel_iterations,
Graph* graph, const FunctionLibraryDefinition& flib,
bool keep_node_fetchable) {
LowerWhileHelper helper(while_op, cond_fn_name, body_fn_name,
parallel_iterations, graph, flib,
keep_node_fetchable);
LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations,
graph, flib, keep_node_fetchable);
return helper.RunInternal();
}
......@@ -70,8 +69,8 @@ class LowerWhileHelper {
// Create a LowerWhileHelper to create the lowering of While op that has cond
// and body functions named `cond_fn_name` and `body_fn_name` respectively in
// the given graph.
LowerWhileHelper(Node* while_op, const string& cond_fn_name,
const string& body_fn_name, int parallel_iterations,
LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
const NameAttrList& body_fn, int parallel_iterations,
Graph* graph, const FunctionLibraryDefinition& flib,
bool keep_node_fetchable);
......@@ -157,8 +156,8 @@ class LowerWhileHelper {
size_t num_loop_inputs_;
};
LowerWhileHelper::LowerWhileHelper(Node* while_op, const string& cond_fn_name,
const string& body_fn_name,
LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
const NameAttrList& body_fn,
int parallel_iterations, Graph* graph,
const FunctionLibraryDefinition& flib,
bool keep_node_fetchable)
......@@ -169,13 +168,19 @@ LowerWhileHelper::LowerWhileHelper(Node* while_op, const string& cond_fn_name,
parallel_iterations_(parallel_iterations),
keep_node_fetchable_(keep_node_fetchable),
debug_info_(*while_op_),
cond_call_builder_(NewName("cond"), cond_fn_name, graph->op_registry(),
cond_call_builder_(NewName("cond"), cond_fn.name(), graph->op_registry(),
&debug_info_),
body_call_builder_(NewName("body"), body_fn_name, graph->op_registry(),
body_call_builder_(NewName("body"), body_fn.name(), graph->op_registry(),
&debug_info_),
num_loop_inputs_(while_op_->num_inputs()) {
cond_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
for (const auto& i : cond_fn.attr()) {
cond_call_builder_.Attr(i.first, i.second);
}
body_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
for (const auto& i : body_fn.attr()) {
body_call_builder_.Attr(i.first, i.second);
}
// We intentionally `resize` instead of `reserve` space in `enter_nodes_`
// because we need to set it's elements out of order in `CreateEnterNodes`.
enter_nodes_.resize(num_loop_inputs_);
......@@ -432,8 +437,8 @@ Status RewriteWhileNode(Node* n, Graph* g,
}
TF_RETURN_IF_ERROR(LowerWhileHelper::Run(
n, cond_attr->func().name(), body_attr->func().name(),
parallel_iterations_attr->i(), g, flib, keep_node_fetchable));
n, cond_attr->func(), body_attr->func(), parallel_iterations_attr->i(), g,
flib, keep_node_fetchable));
g->RemoveNode(n);
return Status::OK();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册