Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
ca7acecc
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ca7acecc
编写于
7月 16, 2019
作者:
E
Eugene Zhulenev
提交者:
TensorFlower Gardener
7月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do not force DT_RESOURCE return node to be on the source node device
PiperOrigin-RevId: 258454960
上级
d38a8fe0
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
140 addition
and
68 deletion
+140
-68
tensorflow/core/common_runtime/process_function_library_runtime.cc
...w/core/common_runtime/process_function_library_runtime.cc
+73
-59
tensorflow/core/framework/function.h
tensorflow/core/framework/function.h
+25
-9
tensorflow/python/eager/function_test.py
tensorflow/python/eager/function_test.py
+42
-0
未找到文件。
tensorflow/core/common_runtime/process_function_library_runtime.cc
浏览文件 @
ca7acecc
...
...
@@ -368,12 +368,18 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
for
(
Node
*
node
:
ret_nodes
)
{
if
(
output_devices
.
empty
())
{
VLOG
(
3
)
<<
"Trying to determine device for node "
<<
node
->
name
();
DataType
dtype
;
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
node
->
attrs
(),
"T"
,
&
dtype
));
VLOG
(
3
)
<<
"Trying to determine device for node "
<<
node
->
name
()
<<
"[T="
<<
DataTypeString
(
dtype
)
<<
"]"
;
// If output_devices are empty, the node producing retval
// must have explicitly assigned device or a colocation constraint
// to a node with explicitly assigned device.
for
(
const
auto
&
it
:
node
->
in_edges
())
{
if
(
!
it
->
IsControlEdge
())
{
if
(
it
->
IsControlEdge
())
continue
;
Node
*
src_node
=
it
->
src
();
const
string
*
src_device
=
AssignedOrRequestedDeviceName
(
*
src_node
);
string
colocation_group
=
""
;
...
...
@@ -395,19 +401,26 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
<<
" colo group: "
<<
colocation_group
;
}
// If a resource is produced by a function call node, we can't trust the
// source node device assignment, because multi-device functions can
// return resources placed on multiple devices. In such a case we leave
// the retval device assignment empty, and rely on the placer to infer the
// correct assignment based on the actual output device.
const
bool
can_use_src_node_device
=
!
(
dtype
==
DT_RESOURCE
&&
IsFunctionCall
(
*
lib_def_
,
*
src_node
));
if
(
!
colocation_group
.
empty
())
{
AttrValue
::
ListValue
colo_attr
;
colo_attr
.
add_s
(
colocation_group
);
std
::
vector
<
string
>
colo_slice
=
{
colocation_group
};
node
->
AddAttr
(
kColocationAttrName
,
colo_slice
);
}
else
if
(
!
src_device
->
empty
()
)
{
}
else
if
(
!
src_device
->
empty
()
&&
can_use_src_node_device
)
{
// src_device can be a partially specified device. Find the
// matching device in the device_set.
DeviceNameUtils
::
ParsedName
parsed
;
if
(
!
DeviceNameUtils
::
ParseFullName
(
*
src_device
,
&
parsed
))
{
return
errors
::
InvalidArgument
(
"Failed to parse explicit device specification "
,
*
src_device
);
"Failed to parse explicit device specification "
,
*
src_device
);
}
std
::
vector
<
Device
*>
matching_devices
;
device_set
.
FindMatchingDevices
(
parsed
,
&
matching_devices
);
...
...
@@ -434,11 +447,12 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
"device. Matched devices are "
,
devices
);
}
VLOG
(
3
)
<<
"Setting output device to "
<<
matching_devices
[
0
]
->
name
()
<<
" for node "
<<
node
->
DebugString
();
VLOG
(
3
)
<<
"Setting output device to "
<<
matching_devices
[
0
]
->
name
()
<<
" for node "
<<
SummarizeNode
(
*
node
);
node
->
set_assigned_device_name
(
matching_devices
[
0
]
->
name
());
}
}
else
if
(
!
src_device
->
empty
()
&&
!
can_use_src_node_device
)
{
VLOG
(
3
)
<<
"Did not set device for a resource output node "
<<
SummarizeNode
(
*
node
);
}
}
}
else
{
...
...
tensorflow/core/framework/function.h
浏览文件 @
ca7acecc
...
...
@@ -536,15 +536,31 @@ class FunctionLibraryRuntime {
std
::
vector
<
string
>
input_devices
;
// For multi-device functions, a vector of canonical device names for
// function's outputs. The device of resource outputs should be the CPU
// device, not the device backing the resource.
// If specified, must have the same length as the number of function
// outputs.
// If not specified, output devices are picked automatically. If operations
// producing the output tensors have explicit device specification, they
// will be respected. These device specifications must identify a unique
// device, i.e. a general specification like "job:foo" matching multiple
// devices will result in an error.
// function's outputs.
//
// (a) If specified (must have the same length as number of outputs):
//
// Specified devices will be assigned to Retval nodes inserted into the
// function body graph in place of function outputs. An output device may
// be specified as an empty string; in this case the Retval device
// assignment will be inferred later, when the function graph is placed
// before partitioning (this is required for resource outputs). Placer will
// respect colocation constraints.
//
// (b) If not specified:
//
// Function runtime will infer the Retval device by following input edges
// until it reaches a node with a device specification. This device
// specification must identify a unique device, i.e. a general specification
// like "job:foo" matching multiple devices will result in an error.
//
// IMPORTANT: Resource outputs
//
// Multi-device functions might return resources on devices different from
// the function call device. If the output device is not specified for a
// resource output, and the node producing that resource is a function call,
// the runtime will leave the device specification empty and rely on Placer
// to infer the correct device.
std
::
vector
<
string
>
output_devices
;
// This interface is EXPERIMENTAL and subject to change.
...
...
tensorflow/python/eager/function_test.py
浏览文件 @
ca7acecc
...
...
@@ -2948,6 +2948,48 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
self
.
assertEqual
(
r1
.
numpy
(),
6.0
)
self
.
assertRegexpMatches
(
r1
.
backing_device
,
'CPU'
)
@
test_util
.
run_gpu_only
def
testReturnResourceFromNestedFunctionCall
(
self
):
"""Test returning GPU resource from noinline function call placed on CPU.
When inferring output devices for the return value, do not set a device for
returns of DT_RESOURCE data type based on the device assignment of the node
that produced that resource. For example, a function call placed on CPU can
return resources on GPU.
"""
with
ops
.
device
(
'/device:GPU:0'
):
g1
=
resource_variable_ops
.
ResourceVariable
(
3.0
)
@
function
.
defun_with_attributes
(
attributes
=
{
'_noinline'
:
True
})
def
inner
(
resource1
):
resource1
.
assign_add
(
2.0
)
return
resource1
*
2
,
resource1
.
handle
@
function
.
defun
def
outer
(
resource1
):
with
ops
.
device
(
'/device:CPU:0'
):
r1
,
res1
=
inner
(
resource1
)
return
r1
,
res1
r1
,
res1
=
outer
(
g1
)
self
.
assertEqual
(
r1
.
numpy
(),
10.0
)
self
.
assertRegexpMatches
(
r1
.
backing_device
,
'CPU'
)
def
check_handle
(
handle
,
expected_value
):
self
.
assertRegexpMatches
(
handle
.
backing_device
,
'CPU'
)
tensor
=
gen_resource_variable_ops
.
read_variable_op
(
handle
,
dtypes
.
float32
)
self
.
assertEqual
(
tensor
.
numpy
(),
expected_value
)
# Check that handles returned from functions are on CPU and an op using
# the resource handle is correctly placed on the device backing the
# resource.
check_handle
(
res1
,
5.0
)
@
test_util
.
run_gpu_only
def
testComplexInputOutputDevicePattern
(
self
):
"""Tests input/output mapping logic in partitioning."""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录