提交 10cab63f 编写于 作者: T Tong Shen 提交者: TensorFlower Gardener

Outside compilation in "If" and "While".

PiperOrigin-RevId: 224933587
上级 0d822c01
......@@ -515,6 +515,7 @@ cc_library(
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
......@@ -613,6 +614,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:scope",
......@@ -625,6 +627,7 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
......
......@@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name, const string& xla_cluster_name,
const NameAttrList& func_name_attrs, const string& new_func_name,
const string& host_graph_func_name,
const std::map<string, int>& host_compute_core,
FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph,
std::vector<string>* shape_inference_graphs, bool* has_outside_compilation);
FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
bool* has_outside_compilation);
// Rewrites XLA computation in `clusters` to replace outside compilation nodes
// with XlaHostCompute, and moves those outside compilations into `g`. If shapes
......
......@@ -41,8 +41,7 @@ Status MakeXlaCompilerArgumentsFromInputs(
*has_uninitialized_vars = false;
*has_tensor_arrays = false;
for (int i = 0; i < ctx->num_inputs(); ++i) {
VLOG(2) << " Input " << i
<< " type: " << DataTypeString(ctx->input_type(i))
VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i))
<< " shape: " << ctx->InputShape(i).DebugString();
XlaCompiler::Argument& arg = (*args)[i];
DataType type = ctx->input_type(i);
......@@ -233,13 +232,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
xla::ShapeUtil::HumanString(body.xla_output_shape)));
xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::PRED, {})});
xla::Shape expected_cond_output_shape_without_side_effect =
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::PRED, {})});
xla::Shape expected_cond_output_shape_with_side_effect =
xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}),
xla::ShapeUtil::MakeTokenShape()});
OP_REQUIRES(ctx,
xla::ShapeUtil::Compatible(cond.xla_output_shape,
expected_cond_output_shape),
xla::ShapeUtil::Compatible(
cond.xla_output_shape,
expected_cond_output_shape_without_side_effect) ||
xla::ShapeUtil::Compatible(
cond.xla_output_shape,
expected_cond_output_shape_with_side_effect),
errors::InvalidArgument(
"Output shape of loop condition should be (pred[]), got: ",
"Output shape of loop condition should be (pred[]) or "
"(pred[], token[]), got: ",
xla::ShapeUtil::HumanString(cond.xla_output_shape)));
int num_inputs = body.input_mapping.size();
......
......@@ -24,6 +24,8 @@ const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";
const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer";
std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) {
std::set<std::string> results;
Node* first_side_effecting_node_on_path = nullptr;
......
......@@ -35,6 +35,9 @@ extern const char kXlaTokenInputNodesAttrName[];
// node has side-effect dependency on current graph's token input.
extern const char kXlaTokenArgNodeName[];
// This node have XlaRecvAtHost/XlaSendFromHost in its associated functions.
extern const char kXlaHasHostTransferAttrName[];
// Calculates side-effect dependencies for the graph's token output.
// Returns a set of node names representing these dependencies.
std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g);
......
......@@ -557,6 +557,12 @@ bool HasAssociatedFunction(const NodeDef& node_def,
return true;
}
if (node_def.op() == "XlaHostCompute") {
// XlaHostCompute has "shape_inference_graph" func attr, but that's not
// related to graph execution.
return false;
}
for (const auto& iter : node_def.attr()) {
if (iter.second.has_func()) {
return true;
......@@ -578,6 +584,9 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
// This is a SymbolicGradient op.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
} else if (node.type_string() == "XlaHostCompute") {
// XlaHostCompute has "shape_inference_graph" func attr, but that's not
// related to graph execution.
} else {
// Collect all function attrs for the node.
for (auto& iter : node.attrs()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册