提交 7d8ab5df 编写于 作者: S Sanjoy Das 提交者: TensorFlower Gardener

Pass in the correct value for allow_resource_ops_in_called_functions

PiperOrigin-RevId: 241430735
上级 22413d6d
......@@ -380,7 +380,7 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
XlaOpRegistry::AutoclusteringPolicy::kAlways;
RecursiveCompilabilityChecker::OperationFilter op_filter;
op_filter.allow_non_resource_var_resource_ops =
op_filter.allow_resource_ops_in_called_functions =
registration.compile_all_resource_ops;
op_filter.allow_non_resource_var_resource_ops =
registration.compile_all_resource_ops;
......
......@@ -277,6 +277,61 @@ TEST(XlaCompilationTest, FunctionCalls) {
EXPECT_TRUE(clusters.find("E") == clusters.cend());
}
TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
FunctionDef compilable = FunctionDefHelper::Define(
"FnWithResourceOp", {"var:resource", "val:float"}, {"retval:float"}, {},
{{{"assign_op"},
"AssignVariableOp",
{"var", "val"},
{{"dtype", DT_FLOAT}}},
{{"retval"}, "Identity", {"val"}, {{"T", DT_FLOAT}}, {"assign_op"}}});
FunctionDefLibrary flib;
*flib.add_function() = compilable;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
std::unique_ptr<Graph> graph(new Graph(&flib_def));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
Node* resource =
ops::SourceOp("VarHandleOp", builder.opts()
.WithName("varhandle")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("shape", TensorShape({})));
Tensor const_tensor(DT_FLOAT, TensorShape({}));
const_tensor.scalar<float>()() = 42.0f;
Node* value = ops::SourceOp("Const", builder.opts()
.WithName("const")
.WithAttr("value", const_tensor)
.WithAttr("dtype", DT_FLOAT));
Node* call = ops::BinaryOp("FnWithResourceOp", resource, value,
builder.opts().WithName("A"));
Node* tanh0 = ops::UnaryOp("Tanh", call, builder.opts().WithName("tanh0"));
Node* tanh1 = ops::UnaryOp("Tanh", tanh0, builder.opts().WithName("tanh1"));
ops::UnaryOp("Tanh", tanh1, builder.opts().WithName("tanh2"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
testing::FindNodeByName(graph.get(), "A")
->set_assigned_device_name(xla_cpu_device);
testing::FindNodeByName(graph.get(), "tanh0")
->set_assigned_device_name(xla_cpu_device);
testing::FindNodeByName(graph.get(), "tanh1")
->set_assigned_device_name(xla_cpu_device);
testing::FindNodeByName(graph.get(), "tanh2")
->set_assigned_device_name(xla_cpu_device);
TF_ASSERT_OK(
MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
auto clusters = GetClusters(*graph);
EXPECT_NE(clusters["A"], "");
}
// Metadata-only operators such as Shape/Rank/Size may not be the root of a
// cluster. This is partially to work around b/26800664, and partially because
// we should probably prefer to compile metadata operators with their producers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册