Commit ca7acecc authored by Eugene Zhulenev, committed by TensorFlower Gardener

Do not force DT_RESOURCE return node to be on the source node device

PiperOrigin-RevId: 258454960
Parent d38a8fe0
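For orientation before the diff, here is a minimal user-level sketch of the scenario this change handles. It is written against the public TF 2.x API (tf.function, tf.Variable) rather than the internal defun used in the test below, and it assumes a GPU is available; it is illustrative, not part of the commit:

import tensorflow as tf

# Resource variable whose buffer lives on GPU.
with tf.device('/device:GPU:0'):
  v = tf.Variable(3.0)

@tf.function
def forward_handle():
  # The subgraph producing the return value is pinned to CPU, yet the
  # DT_RESOURCE it returns is backed by GPU memory. Forcing the retval onto
  # the producing node's device (CPU) would be wrong; after this change the
  # runtime leaves the retval device empty and lets the Placer decide.
  with tf.device('/device:CPU:0'):
    return v.handle

handle = forward_handle()
# An op consuming the handle must still run on the device backing the
# resource, regardless of where the handle-forwarding subgraph was placed.
value = tf.raw_ops.ReadVariableOp(resource=handle, dtype=tf.float32)
print(value.numpy())  # 3.0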
@@ -368,77 +368,91 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
   for (Node* node : ret_nodes) {
     if (output_devices.empty()) {
-      VLOG(3) << "Trying to determine device for node " << node->name();
+      DataType dtype;
+      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
+      VLOG(3) << "Trying to determine device for node " << node->name()
+              << "[T=" << DataTypeString(dtype) << "]";
       // If output_devices are empty, the node producing retval
       // must have explicitly assigned device or a colocation constraint
       // to a node with explicitly assigned device.
       for (const auto& it : node->in_edges()) {
-        if (!it->IsControlEdge()) {
-          Node* src_node = it->src();
-          const string* src_device = AssignedOrRequestedDeviceName(*src_node);
-          string colocation_group = "";
-          GetColocationGroup(src_node, &colocation_group);
-          VLOG(3) << "Considering src: " << src_node->name()
-                  << " src_device: " << *src_device
-                  << " colo group: " << colocation_group;
-          while (src_device->empty() && colocation_group.empty() &&
-                 src_node->IsIdentity()) {
-            // Only follows the real data input of Identity, not control edges.
-            Node* input_node;
-            TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
-            src_node = input_node;
-            src_device = AssignedOrRequestedDeviceName(*src_node);
-            GetColocationGroup(src_node, &colocation_group);
-            VLOG(3) << "Considering src: " << src_node->name()
-                    << " src_device: " << *src_device
-                    << " colo group: " << colocation_group;
-          }
-          if (!colocation_group.empty()) {
-            AttrValue::ListValue colo_attr;
-            colo_attr.add_s(colocation_group);
-            std::vector<string> colo_slice = {colocation_group};
-            node->AddAttr(kColocationAttrName, colo_slice);
-          } else if (!src_device->empty()) {
-            // src_device can be a partially specified device. Find the
-            // matching device in the device_set.
-            DeviceNameUtils::ParsedName parsed;
-            if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
-              return errors::InvalidArgument(
-                  "Failed to parse explicit device specification ",
-                  *src_device);
-            }
-            std::vector<Device*> matching_devices;
-            device_set.FindMatchingDevices(parsed, &matching_devices);
-            if (matching_devices.empty()) {
-              return errors::InvalidArgument(
-                  "Unable to find any devices for spec ", *src_device);
-            } else if (matching_devices.size() != 1) {
-              // Convert a vector of devices to a string.
-              // Using absl::StrJoin did not work in Android builds.
-              string devices = "[";
-              for (Device* device : matching_devices) {
-                devices.append(device->name());
-                devices.append(", ");
-              }
-              if (devices.size() > 2) {
-                devices.resize(devices.size() - 2);
-              }
-              devices.append("]");
-              return errors::InvalidArgument(
-                  "When FunctionLibraryRuntime::Options.output_devices are "
-                  "not specified for a multi-device function, the device "
-                  "specification on the output node must match exactly one "
-                  "device. Matched devices are ",
-                  devices);
-            }
-            VLOG(3) << "Setting output device to "
-                    << matching_devices[0]->name() << " for node "
-                    << node->DebugString();
-            node->set_assigned_device_name(matching_devices[0]->name());
-          }
-        }
+        if (it->IsControlEdge()) continue;
+
+        Node* src_node = it->src();
+        const string* src_device = AssignedOrRequestedDeviceName(*src_node);
+        string colocation_group = "";
+        GetColocationGroup(src_node, &colocation_group);
+        VLOG(3) << "Considering src: " << src_node->name()
+                << " src_device: " << *src_device
+                << " colo group: " << colocation_group;
+        while (src_device->empty() && colocation_group.empty() &&
+               src_node->IsIdentity()) {
+          // Only follows the real data input of Identity, not control edges.
+          Node* input_node;
+          TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
+          src_node = input_node;
+          src_device = AssignedOrRequestedDeviceName(*src_node);
+          GetColocationGroup(src_node, &colocation_group);
+          VLOG(3) << "Considering src: " << src_node->name()
+                  << " src_device: " << *src_device
+                  << " colo group: " << colocation_group;
+        }
+
+        // If the resource is produced by a function call node, we can't trust
+        // the source node device assignment, because multi-device functions
+        // can return resources placed on multiple devices. In such a case we
+        // leave the retval device assignment empty and rely on the placer to
+        // infer the correct assignment based on the actual output device.
+        const bool can_use_src_node_device =
+            !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def_, *src_node));
+
+        if (!colocation_group.empty()) {
+          AttrValue::ListValue colo_attr;
+          colo_attr.add_s(colocation_group);
+          std::vector<string> colo_slice = {colocation_group};
+          node->AddAttr(kColocationAttrName, colo_slice);
+        } else if (!src_device->empty() && can_use_src_node_device) {
+          // src_device can be a partially specified device. Find the
+          // matching device in the device_set.
+          DeviceNameUtils::ParsedName parsed;
+          if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
+            return errors::InvalidArgument(
+                "Failed to parse explicit device specification ", *src_device);
+          }
+          std::vector<Device*> matching_devices;
+          device_set.FindMatchingDevices(parsed, &matching_devices);
+          if (matching_devices.empty()) {
+            return errors::InvalidArgument(
+                "Unable to find any devices for spec ", *src_device);
+          } else if (matching_devices.size() != 1) {
+            // Convert a vector of devices to a string.
+            // Using absl::StrJoin did not work in Android builds.
+            string devices = "[";
+            for (Device* device : matching_devices) {
+              devices.append(device->name());
+              devices.append(", ");
+            }
+            if (devices.size() > 2) {
+              devices.resize(devices.size() - 2);
+            }
+            devices.append("]");
+            return errors::InvalidArgument(
+                "When FunctionLibraryRuntime::Options.output_devices are "
+                "not specified for a multi-device function, the device "
+                "specification on the output node must match exactly one "
+                "device. Matched devices are ",
+                devices);
+          }
+          VLOG(3) << "Setting output device to " << matching_devices[0]->name()
+                  << " for node " << SummarizeNode(*node);
+          node->set_assigned_device_name(matching_devices[0]->name());
+        } else if (!src_device->empty() && !can_use_src_node_device) {
+          VLOG(3) << "Did not set device for a resource output node "
+                  << SummarizeNode(*node);
+        }
       }
     } else {
...
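The hunk above walks backwards from each retval node, skipping device-less Identity nodes, then decides whether the discovered device can be trusted. A rough, runnable Python paraphrase of that logic follows; the Node class and helper names here are hypothetical stand-ins for the C++ internals, not a real TensorFlow API:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
  op: str
  device: str = ''             # assigned-or-requested device, may be empty
  colocation_group: str = ''   # colocation class attr, may be empty
  data_inputs: List['Node'] = field(default_factory=list)

def infer_retval_device(src: Node, dtype: str,
                        is_function_call) -> Optional[str]:
  # Walk through Identity nodes that carry neither a device nor a
  # colocation constraint, following only the real data input.
  while not src.device and not src.colocation_group and src.op == 'Identity':
    src = src.data_inputs[0]
  # A DT_RESOURCE produced by a function call may live on a different
  # device than the call node, so its device can't be trusted: leave the
  # retval unplaced and let the Placer infer it.
  if dtype == 'DT_RESOURCE' and is_function_call(src):
    return None
  # Colocation constraints win over a plain device specification.
  return src.colocation_group or src.device or None

# Example: a retval fed through an Identity from a GPU-placed VarHandleOp
# keeps the GPU device, while a resource returned by a call node does not.
var = Node('VarHandleOp', device='/device:GPU:0')
ident = Node('Identity', data_inputs=[var])
call = Node('PartitionedCall', device='/device:CPU:0')
is_call = lambda n: n.op == 'PartitionedCall'
print(infer_retval_device(ident, 'DT_RESOURCE', is_call))  # /device:GPU:0
print(infer_retval_device(call, 'DT_RESOURCE', is_call))   # None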
@@ -536,15 +536,31 @@ class FunctionLibraryRuntime {
     std::vector<string> input_devices;
     // For multi-device functions, a vector of canonical device names for
-    // function's outputs. The device of resource outputs should be the CPU
-    // device, not the device backing the resource.
-    // If specified, must have the same length as the number of function
-    // outputs.
-    // If not specified, output devices are picked automatically. If operations
-    // producing the output tensors have explicit device specification, they
-    // will be respected. These device specifications must identify a unique
-    // device, i.e. a general specification like "job:foo" matching multiple
-    // devices will result in an error.
+    // function's outputs.
+    //
+    // (a) If specified (must have the same length as the number of outputs):
+    //
+    // Specified devices will be assigned to Retval nodes inserted into the
+    // function body graph in place of function outputs. An output device may
+    // be specified as an empty string; in this case the Retval device
+    // assignment will be inferred later, when the function graph is placed
+    // before partitioning (this is required for resource outputs). Placer
+    // will respect colocation constraints.
+    //
+    // (b) If not specified:
+    //
+    // The function runtime will infer the Retval device by following input
+    // edges until it reaches a node with a device specification. This device
+    // specification must identify a unique device, i.e. a general
+    // specification like "job:foo" matching multiple devices will result in
+    // an error.
+    //
+    // IMPORTANT: Resource outputs
+    //
+    // Multi-device functions might return resources on a device different
+    // from the function call device. If the output device is not specified
+    // for a resource output, and the node producing that resource is a
+    // function call, the runtime will leave the device specification empty
+    // and rely on the Placer to infer the correct device.
     std::vector<string> output_devices;
     // This interface is EXPERIMENTAL and subject to change.
...
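As a user-level illustration of case (b) above, here is a hedged sketch using the public tf.function API: when no output devices are requested, the runtime follows the retval's input edge and pins it to the device of the op that produced it.

import tensorflow as tf

@tf.function
def f():
  with tf.device('/device:CPU:0'):
    # Explicit device on the producing op. With output_devices unspecified,
    # the runtime follows this input edge and pins the Retval to CPU.
    return tf.constant(1.0) * 2.0

y = f()
print(y.backing_device)  # expected to mention CPU (exact string varies)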
@@ -2948,6 +2948,48 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(r1.numpy(), 6.0)
     self.assertRegexpMatches(r1.backing_device, 'CPU')
 
+  @test_util.run_gpu_only
+  def testReturnResourceFromNestedFunctionCall(self):
+    """Test returning GPU resource from noinline function call placed on CPU.
+
+    When inferring output devices for the return value, do not set a device
+    for returns of DT_RESOURCE data type based on the device assignment of
+    the node that produced that resource. As an example, a function call
+    placed on CPU can return resources on GPU.
+    """
+    with ops.device('/device:GPU:0'):
+      g1 = resource_variable_ops.ResourceVariable(3.0)
+
+    @function.defun_with_attributes(attributes={
+        '_noinline': True
+    })
+    def inner(resource1):
+      resource1.assign_add(2.0)
+      return resource1 * 2, resource1.handle
+
+    @function.defun
+    def outer(resource1):
+      with ops.device('/device:CPU:0'):
+        r1, res1 = inner(resource1)
+      return r1, res1
+
+    r1, res1 = outer(g1)
+
+    self.assertEqual(r1.numpy(), 10.0)
+    self.assertRegexpMatches(r1.backing_device, 'CPU')
+
+    def check_handle(handle, expected_value):
+      self.assertRegexpMatches(handle.backing_device, 'CPU')
+      tensor = gen_resource_variable_ops.read_variable_op(
+          handle, dtypes.float32)
+      self.assertEqual(tensor.numpy(), expected_value)
+
+    # Check that handles returned from functions are on CPU and an op using
+    # the resource handle is correctly placed on the device backing the
+    # resource.
+    check_handle(res1, 5.0)
+
   @test_util.run_gpu_only
   def testComplexInputOutputDevicePattern(self):
     """Tests input/output mapping logic in partitioning."""
...