Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4abef501
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4abef501
编写于
4月 16, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
code refine
上级
2aaa75ec
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
69 addition
and
50 deletion
+69
-50
paddle/fluid/framework/details/broadcast_op_handle.cc
paddle/fluid/framework/details/broadcast_op_handle.cc
+31
-27
paddle/fluid/framework/details/broadcast_op_handle.h
paddle/fluid/framework/details/broadcast_op_handle.h
+6
-0
paddle/fluid/framework/details/gather_op_handle.cc
paddle/fluid/framework/details/gather_op_handle.cc
+27
-23
paddle/fluid/framework/details/gather_op_handle.h
paddle/fluid/framework/details/gather_op_handle.h
+5
-0
未找到文件。
paddle/fluid/framework/details/broadcast_op_handle.cc
浏览文件 @
4abef501
...
...
@@ -34,40 +34,21 @@ BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
:
local_scopes_
(
local_scopes
),
places_
(
places
)
{}
void
BroadcastOpHandle
::
RunImpl
()
{
// the input may have dummy var.
std
::
vector
<
VarHandle
*>
in_var_handle
;
for
(
auto
*
in
:
inputs_
)
{
auto
*
out_handle
=
dynamic_cast
<
VarHandle
*>
(
in
);
if
(
out_handle
)
{
in_var_handle
.
push_back
(
out_handle
);
}
}
// the input and output may have dummy var.
std
::
vector
<
VarHandle
*>
in_var_handle
=
GetValidVarHandles
(
inputs_
);
std
::
vector
<
VarHandle
*>
out_var_handles
=
GetValidVarHandles
(
outputs_
);
PADDLE_ENFORCE_EQ
(
in_var_handle
.
size
(),
1
,
"The number of input should be one."
);
// the output may have dummy var.
std
::
vector
<
VarHandle
*>
out_var_handles
;
for
(
auto
*
out
:
outputs_
)
{
auto
*
out_handle
=
dynamic_cast
<
VarHandle
*>
(
out
);
if
(
out_handle
)
{
out_var_handles
.
push_back
(
out_handle
);
}
}
PADDLE_ENFORCE_EQ
(
out_var_handles
.
size
(),
places_
.
size
(),
"The number of output should equal to the number of places."
);
// Wait input done, this Wait is asynchronous operation
auto
&
in_place
=
in_var_handle
[
0
]
->
place_
;
if
(
in_var_handle
[
0
]
->
generated_op_
)
{
for
(
auto
*
out
:
out_var_handles
)
{
auto
&
out_p
=
out
->
place_
;
in_var_handle
[
0
]
->
generated_op_
->
Wait
(
dev_ctxes_
[
out_p
]);
}
}
// Wait input done, this Wait is asynchronous operationplatform::Place
// &in_place;
WaitEvents
(
out_var_handles
,
in_var_handle
);
//
auto
in_place
=
in_var_handle
[
0
]
->
place_
;
auto
in_scope_idx
=
in_var_handle
[
0
]
->
scope_idx_
;
auto
in_var
=
local_scopes_
.
at
(
in_scope_idx
)
->
FindVar
(
in_var_handle
[
0
]
->
name_
);
...
...
@@ -107,6 +88,29 @@ void BroadcastOpHandle::RunImpl() {
}
}
void
BroadcastOpHandle
::
WaitEvents
(
const
std
::
vector
<
VarHandle
*>
&
out_var_handles
,
const
std
::
vector
<
VarHandle
*>
&
in_var_handle
)
{
if
(
in_var_handle
[
0
]
->
generated_op_
)
{
for
(
auto
*
out
:
out_var_handles
)
{
auto
&
out_p
=
out
->
place_
;
in_var_handle
[
0
]
->
generated_op_
->
Wait
(
dev_ctxes_
[
out_p
]);
}
}
}
std
::
vector
<
VarHandle
*>
BroadcastOpHandle
::
GetValidVarHandles
(
const
std
::
vector
<
VarHandleBase
*>
&
inputs
)
{
std
::
vector
<
VarHandle
*>
in_var_handle
;
for
(
auto
*
in
:
inputs
)
{
auto
*
out_handle
=
dynamic_cast
<
VarHandle
*>
(
in
);
if
(
out_handle
)
{
in_var_handle
.
push_back
(
out_handle
);
}
}
return
in_var_handle
;
}
std
::
string
BroadcastOpHandle
::
Name
()
const
{
return
"broadcast"
;
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/broadcast_op_handle.h
浏览文件 @
4abef501
...
...
@@ -41,6 +41,12 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
std
::
vector
<
VarHandle
*>
GetValidVarHandles
(
const
std
::
vector
<
VarHandleBase
*>
&
inputs
);
void
WaitEvents
(
const
std
::
vector
<
VarHandle
*>
&
out_var_handles
,
const
std
::
vector
<
VarHandle
*>
&
in_var_handle
);
};
}
// namespace details
...
...
paddle/fluid/framework/details/gather_op_handle.cc
浏览文件 @
4abef501
...
...
@@ -23,26 +23,13 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
:
local_scopes_
(
local_scopes
),
places_
(
places
)
{}
void
GatherOpHandle
::
RunImpl
()
{
// the input may have dummy var.
std
::
vector
<
VarHandle
*>
in_var_handles
;
for
(
auto
*
in
:
inputs_
)
{
auto
*
in_handle
=
dynamic_cast
<
VarHandle
*>
(
in
);
if
(
in_handle
)
{
in_var_handles
.
push_back
(
in_handle
);
}
}
// the input and output may have dummy var.
std
::
vector
<
VarHandle
*>
in_var_handles
=
GetValidVarHandles
(
inputs_
);
std
::
vector
<
VarHandle
*>
out_var_handles
=
GetValidVarHandles
(
outputs_
);
PADDLE_ENFORCE_EQ
(
in_var_handles
.
size
(),
places_
.
size
(),
"The number of output should equal to the number of places."
);
// the output may have dummy var.
std
::
vector
<
VarHandle
*>
out_var_handles
;
for
(
auto
*
out
:
outputs_
)
{
auto
*
out_handle
=
dynamic_cast
<
VarHandle
*>
(
out
);
if
(
out_handle
)
{
out_var_handles
.
push_back
(
out_handle
);
}
}
PADDLE_ENFORCE_EQ
(
out_var_handles
.
size
(),
1
,
"The number of output should be one."
);
...
...
@@ -58,11 +45,7 @@ void GatherOpHandle::RunImpl() {
"The place of input and output should be the same."
);
// Wait input done, this Wait is asynchronous operation
for
(
auto
*
in
:
in_var_handles
)
{
if
(
in
->
generated_op_
)
{
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
in
->
place_
]);
}
}
WaitEvents
(
in_var_handles
);
std
::
vector
<
int64_t
>
out_rows
;
std
::
vector
<
Tensor
>
in_tensors
;
...
...
@@ -111,7 +94,7 @@ void GatherOpHandle::RunImpl() {
// copy
auto
dev_ctx
=
dev_ctxes_
[
out_place
];
RunAndRecordEvent
(
out_place
,
[
in_tensors
,
out_
va
r
,
dev_ctx
,
out_place
]
{
RunAndRecordEvent
(
out_place
,
[
in_tensors
,
out_
tenso
r
,
dev_ctx
,
out_place
]
{
int
s
=
0
,
e
=
0
;
for
(
size_t
j
=
0
;
j
<
in_tensors
.
size
();
++
j
)
{
e
+=
in_tensors
[
j
].
dims
()[
0
];
...
...
@@ -123,6 +106,27 @@ void GatherOpHandle::RunImpl() {
});
}
void
GatherOpHandle
::
WaitEvents
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
)
{
for
(
auto
*
in
:
in_var_handles
)
{
if
(
in
->
generated_op_
)
{
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
in
->
place_
]);
}
}
}
std
::
vector
<
VarHandle
*>
GatherOpHandle
::
GetValidVarHandles
(
const
std
::
vector
<
VarHandleBase
*>
&
inputs
)
{
std
::
vector
<
VarHandle
*>
in_var_handles
;
for
(
auto
*
in
:
inputs
)
{
auto
*
in_handle
=
dynamic_cast
<
VarHandle
*>
(
in
);
if
(
in_handle
)
{
in_var_handles
.
push_back
(
in_handle
);
}
}
return
in_var_handles
;
}
std
::
string
GatherOpHandle
::
Name
()
const
{
return
"gather"
;
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/gather_op_handle.h
浏览文件 @
4abef501
...
...
@@ -41,6 +41,11 @@ struct GatherOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
std
::
vector
<
VarHandle
*>
GetValidVarHandles
(
const
std
::
vector
<
VarHandleBase
*>
&
);
void
WaitEvents
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
);
};
}
// namespace details
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录