提交 f0fd2bed 编写于 作者: E Eugene Zhulenev 提交者: TensorFlower Gardener

Do not force PartitionedCall DT_RESOURCE outputs to be on CPU device

PiperOrigin-RevId: 258422018
上级 5819055a
......@@ -569,12 +569,12 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
int index = 0;
VLOG(3) << "Requested input devices:";
for (const string& device : options.input_devices) {
VLOG(3) << " " << device << " for input at index " << index++;
VLOG(3) << " [input " << index++ << "] " << device;
}
index = 0;
VLOG(3) << "Requested output devices:";
for (const string& device : options.output_devices) {
VLOG(3) << " " << device << " for output at index " << index++;
VLOG(3) << " [output " << index++ << "] " << device;
}
}
......
......@@ -57,6 +57,9 @@ namespace {
constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;
// Do not specialize functions marked with '_nospecialize' attribute.
constexpr const char* const kNoSpecializeAttr = "_nospecialize";
// Mark functions that were created as a result of function specialization.
constexpr const char* const kGrapplerSpecializedFuncAttr =
"_GrapplerSpecializedFunc";
......@@ -140,6 +143,13 @@ class FakeDevice : public Device {
// it (see details in MetaOptimizer). Also we can push known constant inputs
// into the function body, and remove unused outputs/inputs.
bool MarkedNoSpecialize(const FunctionDef& fdef) {
const auto attr = AttrSlice(&fdef.attr());
bool nospecialize = false;
return GetNodeAttr(attr, kNoSpecializeAttr, &nospecialize).ok() &&
nospecialize;
}
// Specialized function instantiation type parameters, body parameters, and
// const inputs.
struct FunctionSpecializationSignature {
......@@ -1388,13 +1398,15 @@ Status FunctionOptimizer::RunFunctionOptimizerPass(
// Specialize it to its instantiation context if it has something worth
// specializing.
bool specialization_worthy = IsParametrized(*func) ||
HasTrulyConstInputs(node, ctx) ||
HasUnusedOutputs(node, *func, ctx);
// Do not specialize if function has custom gradient.
const bool specialization_worthy = IsParametrized(*func) ||
HasTrulyConstInputs(node, ctx) ||
HasUnusedOutputs(node, *func, ctx);
// Do not specialize if function has custom gradient or marked nospecialize.
const string grad_func = ctx.function_library().FindGradient(func_name);
const bool no_specialize = !grad_func.empty() || MarkedNoSpecialize(*func);
if (grad_func.empty() && specialization_worthy) {
if (specialization_worthy && !no_specialize) {
// TODO(ezhulenev): Specialize function call if input has a known shape.
// Specialize function body for its instantiation attributes and inputs.
Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
......
......@@ -145,7 +145,13 @@ Status PartitionedCallOp::FillOutputDevices(
DataTypeVector dtypes;
TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
for (DataType dtype : dtypes) {
if (MTypeFromDType(dtype) == HOST_MEMORY) {
if (dtype == DT_RESOURCE) {
// Resource memory type is HOST_MEMORY, however the actual resource
// might be allocated on a device. We leave output device for resource
// outputs empty, and rely on a Placer and colocation constraints to
// infer correct placement for the function output.
opts->output_devices.push_back("");
} else if (MTypeFromDType(dtype) == HOST_MEMORY) {
opts->output_devices.push_back(cpu_device.name());
} else {
opts->output_devices.push_back(opts->target);
......
......@@ -2918,6 +2918,36 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
check_handle(res1, 3.0)
check_handle(res2, 2.0)
@test_util.run_gpu_only
def testPassResourceThroughNestedFunctionCall(self):
"""Test passing GPU resource to noinline function call placed on CPU.
PartitionedCallOp must not enforce any particular device assignment for the
resource output. Inner function marked as `_nospecialize`, so Grappler would
not prune unused function output.
"""
with ops.device('/device:GPU:0'):
g1 = resource_variable_ops.ResourceVariable(3.0)
@function.defun_with_attributes(attributes={
'_noinline': True,
'_nospecialize': True
})
def inner(resource1):
return resource1 * 2, resource1.handle
@function.defun
def outer(resource1):
with ops.device('/device:CPU:0'):
r1, _ = inner(resource1)
return r1
r1 = outer(g1)
self.assertEqual(r1.numpy(), 6.0)
self.assertRegexpMatches(r1.backing_device, 'CPU')
@test_util.run_gpu_only
def testComplexInputOutputDevicePattern(self):
"""Tests input/output mapping logic in partitioning."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册