Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3db9fad7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3db9fad7
编写于
11月 08, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into fix_vlog
test=develop
上级
3da43dca
387610aa
变更
37
隐藏空白更改
内联
并排
Showing
37 changed file
with
879 addition
and
213 deletion
+879
-213
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/framework/details/broadcast_op_handle_test.h
paddle/fluid/framework/details/broadcast_op_handle_test.h
+27
-25
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
...uid/framework/details/fast_threaded_ssa_graph_executor.cc
+10
-8
paddle/fluid/framework/details/fetch_op_handle.cc
paddle/fluid/framework/details/fetch_op_handle.cc
+1
-5
paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc
...fluid/framework/details/fused_broadcast_op_handle_test.cc
+18
-16
paddle/fluid/framework/details/gather_op_handle_test.cc
paddle/fluid/framework/details/gather_op_handle_test.cc
+19
-17
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc
...framework/details/modify_op_lock_and_record_event_pass.cc
+3
-2
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
...fluid/framework/details/multi_devices_graph_check_pass.cc
+7
-8
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+57
-49
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+12
-4
paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
...fluid/framework/details/multi_devices_graph_print_pass.cc
+2
-1
paddle/fluid/framework/details/multi_devices_helper.h
paddle/fluid/framework/details/multi_devices_helper.h
+3
-12
paddle/fluid/framework/details/op_graph_view.cc
paddle/fluid/framework/details/op_graph_view.cc
+6
-17
paddle/fluid/framework/details/op_graph_view.h
paddle/fluid/framework/details/op_graph_view.h
+2
-7
paddle/fluid/framework/details/op_handle_base.h
paddle/fluid/framework/details/op_handle_base.h
+4
-1
paddle/fluid/framework/details/reduce_op_handle_test.cc
paddle/fluid/framework/details/reduce_op_handle_test.cc
+2
-2
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+12
-13
paddle/fluid/framework/details/ssa_graph_executor.cc
paddle/fluid/framework/details/ssa_graph_executor.cc
+4
-2
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+1
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+12
-10
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+7
-7
paddle/fluid/framework/details/var_handle.cc
paddle/fluid/framework/details/var_handle.cc
+6
-0
paddle/fluid/framework/details/var_handle.h
paddle/fluid/framework/details/var_handle.h
+8
-1
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+9
-0
paddle/fluid/framework/ir/graph_helper.h
paddle/fluid/framework/ir/graph_helper.h
+9
-0
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+56
-1
paddle/fluid/framework/ir/node_test.cc
paddle/fluid/framework/ir/node_test.cc
+80
-0
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+2
-2
paddle/fluid/inference/api/analysis_predictor_tester.cc
paddle/fluid/inference/api/analysis_predictor_tester.cc
+1
-1
paddle/fluid/operators/space_to_depth_op.cc
paddle/fluid/operators/space_to_depth_op.cc
+131
-0
paddle/fluid/operators/space_to_depth_op.cu
paddle/fluid/operators/space_to_depth_op.cu
+30
-0
paddle/fluid/operators/space_to_depth_op.h
paddle/fluid/operators/space_to_depth_op.h
+127
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+61
-0
python/paddle/fluid/op.py
python/paddle/fluid/op.py
+2
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+11
-0
python/paddle/fluid/tests/unittests/test_space_to_depth_op.py
...on/paddle/fluid/tests/unittests/test_space_to_depth_op.py
+135
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
3db9fad7
...
@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
...
@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
...
...
paddle/fluid/framework/details/broadcast_op_handle_test.h
浏览文件 @
3db9fad7
...
@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
...
@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
param_scopes_
;
std
::
vector
<
Scope
*>
param_scopes_
;
Scope
g_scope_
;
Scope
g_scope_
;
std
::
unique_ptr
<
OpHandleBase
>
op_handle_
;
OpHandleBase
*
op_handle_
;
std
::
vector
<
std
::
unique_ptr
<
VarHandleBase
>>
vars_
;
std
::
vector
<
VarHandleBase
*>
vars_
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes_
;
std
::
vector
<
p
::
Place
>
place_list_
;
std
::
vector
<
p
::
Place
>
place_list_
;
bool
use_gpu_
;
bool
use_gpu_
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
...
@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
}
}
void
InitBroadcastOp
(
size_t
input_scope_idx
)
{
void
InitBroadcastOp
(
size_t
input_scope_idx
)
{
nodes_
.
clear
();
for
(
size_t
j
=
0
;
j
<
place_list_
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
place_list_
.
size
();
++
j
)
{
local_scopes_
.
push_back
(
&
(
g_scope_
.
NewScope
()));
local_scopes_
.
push_back
(
&
(
g_scope_
.
NewScope
()));
Scope
&
local_scope
=
local_scopes_
.
back
()
->
NewScope
();
Scope
&
local_scope
=
local_scopes_
.
back
()
->
NewScope
();
...
@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
...
@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
}
}
param_scopes_
[
input_scope_idx
]
->
Var
(
"input"
);
param_scopes_
[
input_scope_idx
]
->
Var
(
"input"
);
std
::
unique_ptr
<
ir
::
Node
>
n
=
nodes_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node0"
,
ir
::
Node
::
Type
::
kOperation
);
ir
::
CreateNodeForTest
(
"node0"
,
ir
::
Node
::
Type
::
kOperation
)
)
;
if
(
use_gpu_
)
{
if
(
use_gpu_
)
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
op_handle_
.
reset
(
new
BroadcastOpHandle
(
n
.
get
(),
local_scopes_
,
op_handle_
=
new
BroadcastOpHandle
(
nodes_
.
back
()
.
get
(),
local_scopes_
,
place_list_
,
nccl_ctxs_
.
get
()
));
place_list_
,
nccl_ctxs_
.
get
(
));
#else
#else
PADDLE_THROW
(
"CUDA is not support."
);
PADDLE_THROW
(
"CUDA is not support."
);
#endif
#endif
}
else
{
}
else
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
op_handle_
.
reset
(
new
BroadcastOpHandle
(
n
.
get
(),
local_scopes_
,
op_handle_
=
new
BroadcastOpHandle
(
nodes_
.
back
()
.
get
(),
local_scopes_
,
place_list_
,
nccl_ctxs_
.
get
()
));
place_list_
,
nccl_ctxs_
.
get
(
));
#else
#else
op_handle_
.
reset
(
op_handle_
=
new
BroadcastOpHandle
(
nodes_
.
back
().
get
(),
local_scopes_
,
new
BroadcastOpHandle
(
n
.
get
(),
local_scopes_
,
place_list_
)
);
place_list_
);
#endif
#endif
}
}
std
::
unique_ptr
<
ir
::
Node
>
v
=
nodes_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node1"
,
ir
::
Node
::
Type
::
kVariable
);
ir
::
CreateNodeForTest
(
"node1"
,
ir
::
Node
::
Type
::
kVariable
)
)
;
auto
*
in_var_handle
=
new
VarHandle
(
v
.
get
(),
1
,
input_scope_idx
,
"input"
,
auto
*
in_var_handle
=
new
VarHandle
(
nodes_
.
back
().
get
(),
1
,
input_scope_idx
,
place_list_
[
input_scope_idx
]);
"input"
,
place_list_
[
input_scope_idx
]);
vars_
.
emplace_back
(
in_var_handle
);
vars_
.
emplace_back
(
in_var_handle
);
op_handle_
->
AddInput
(
in_var_handle
);
op_handle_
->
AddInput
(
in_var_handle
);
// add dummy var
// add dummy var
std
::
unique_ptr
<
ir
::
Node
>
v2
=
nodes_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node2"
,
ir
::
Node
::
Type
::
kVariable
);
ir
::
CreateNodeForTest
(
"node2"
,
ir
::
Node
::
Type
::
kVariable
)
)
;
vars_
.
emplace_back
(
new
DummyVarHandle
(
v2
.
get
()));
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes_
.
back
()
.
get
()));
DummyVarHandle
*
dummy_var_handle
=
DummyVarHandle
*
dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
()
.
get
()
);
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
());
dummy_var_handle
->
ClearGeneratedOp
();
dummy_var_handle
->
ClearGeneratedOp
();
op_handle_
->
AddInput
(
dummy_var_handle
);
op_handle_
->
AddInput
(
dummy_var_handle
);
...
@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
...
@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
if
(
!
use_gpu_
)
{
if
(
!
use_gpu_
)
{
op_handle_
->
SetDeviceContext
(
place_list_
[
j
],
ctxs_
[
j
].
get
());
op_handle_
->
SetDeviceContext
(
place_list_
[
j
],
ctxs_
[
j
].
get
());
}
}
std
::
unique_ptr
<
ir
::
Node
>
v3
=
nodes_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node3"
,
ir
::
Node
::
Type
::
kVariable
);
ir
::
CreateNodeForTest
(
"node3"
,
ir
::
Node
::
Type
::
kVariable
)
)
;
VarHandle
*
out_var_handle
=
VarHandle
*
out_var_handle
=
new
VarHandle
(
v3
.
get
(),
2
,
j
,
"out"
,
place_list_
[
j
]);
new
VarHandle
(
nodes_
.
back
()
.
get
(),
2
,
j
,
"out"
,
place_list_
[
j
]);
vars_
.
emplace_back
(
out_var_handle
);
vars_
.
emplace_back
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
}
}
// add dummy var
// add dummy var
std
::
unique_ptr
<
ir
::
Node
>
v4
=
nodes_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node4"
,
ir
::
Node
::
Type
::
kVariable
);
ir
::
CreateNodeForTest
(
"node4"
,
ir
::
Node
::
Type
::
kVariable
)
)
;
vars_
.
emplace_back
(
new
DummyVarHandle
(
v4
.
get
()));
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes_
.
back
()
.
get
()));
DummyVarHandle
*
out_dummy_var_handle
=
DummyVarHandle
*
out_dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
()
.
get
()
);
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
());
out_dummy_var_handle
->
ClearGeneratedOp
();
out_dummy_var_handle
->
ClearGeneratedOp
();
op_handle_
->
AddOutput
(
out_dummy_var_handle
);
op_handle_
->
AddOutput
(
out_dummy_var_handle
);
}
}
...
...
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
浏览文件 @
3db9fad7
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
...
@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
pool_
(
strategy
.
num_threads_
+
pool_
(
strategy
.
num_threads_
+
1
),
// add one more thread for generate op_deps
1
),
// add one more thread for generate op_deps
fetch_ctxs_
(
places
)
{
fetch_ctxs_
(
places
)
{
auto
&
ops
=
graph_
->
Get
<
details
::
GraphOps
>
(
"ops"
);
for
(
auto
&
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph_
))
{
for
(
auto
&
op
:
ops
)
{
int
dep
=
static_cast
<
int
>
(
op
->
NotReadyInputSize
());
int
dep
=
static_cast
<
int
>
(
op
->
NotReadyInputSize
());
op_deps_
.
emplace
(
op
.
get
()
,
dep
);
op_deps_
.
emplace
(
op
,
dep
);
if
(
dep
==
0
)
{
if
(
dep
==
0
)
{
bootstrap_ops_
.
emplace_back
(
op
.
get
()
);
bootstrap_ops_
.
emplace_back
(
op
);
}
}
}
}
...
@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
...
@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
paddle
::
framework
::
FeedFetchList
fetches
;
paddle
::
framework
::
FeedFetchList
fetches
;
fetches
.
resize
(
fetch_tensors
.
size
());
fetches
.
resize
(
fetch_tensors
.
size
());
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>
>
fetch_ops
;
std
::
vector
<
FetchOpHandle
*
>
fetch_ops
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
fetched_vars
[
fetch_var_name
].
push_back
(
*
it
->
second
.
rbegin
());
}
}
}
}
}
}
...
@@ -110,7 +109,10 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
...
@@ -110,7 +109,10 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
complete_q
->
Pop
();
complete_q
->
Pop
();
}
}
}
}
exception_
.
ReThrow
();
if
(
exception_
.
IsCaught
())
{
ClearFetchOp
(
graph_
.
get
(),
&
fetch_ops
);
exception_
.
ReThrow
();
}
}
}
num_complete
+=
num_comp
;
num_complete
+=
num_comp
;
}
}
...
...
paddle/fluid/framework/details/fetch_op_handle.cc
浏览文件 @
3db9fad7
...
@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
...
@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
offset_
(
offset
),
offset_
(
offset
),
local_scopes_
(
local_scopes
)
{}
local_scopes_
(
local_scopes
)
{}
FetchOpHandle
::~
FetchOpHandle
()
{
FetchOpHandle
::~
FetchOpHandle
()
{}
for
(
auto
*
input_var
:
inputs_
)
{
input_var
->
RemoveOutput
(
this
,
this
->
Node
());
}
}
void
FetchOpHandle
::
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
{
void
FetchOpHandle
::
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
{
PADDLE_THROW
(
"Nobody should wait FetchOp. Unexpceted Error"
);
PADDLE_THROW
(
"Nobody should wait FetchOp. Unexpceted Error"
);
...
...
paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc
浏览文件 @
3db9fad7
...
@@ -22,8 +22,10 @@ namespace details {
...
@@ -22,8 +22,10 @@ namespace details {
struct
TestFusedBroadcastOpHandle
:
TestBroadcastOpHandle
{
struct
TestFusedBroadcastOpHandle
:
TestBroadcastOpHandle
{
std
::
vector
<
std
::
string
>
out_varnames_
;
std
::
vector
<
std
::
string
>
out_varnames_
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes_
;
void
InitFusedBroadcastOp
(
std
::
vector
<
size_t
>
input_scope_idxes
)
{
void
InitFusedBroadcastOp
(
std
::
vector
<
size_t
>
input_scope_idxes
)
{
nodes_
.
clear
();
// initialize scope and var
// initialize scope and var
for
(
size_t
i
=
0
;
i
<
place_list_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
place_list_
.
size
();
++
i
)
{
local_scopes_
.
push_back
(
&
(
g_scope_
.
NewScope
()));
local_scopes_
.
push_back
(
&
(
g_scope_
.
NewScope
()));
...
@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
...
@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
}
}
// create op handle node
// create op handle node
std
::
unique_ptr
<
ir
::
Node
>
n
=
nodes_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"fused_broadcast"
,
ir
::
Node
::
Type
::
kOperation
);
ir
::
CreateNodeForTest
(
"fused_broadcast"
,
ir
::
Node
::
Type
::
kOperation
)
)
;
if
(
use_gpu_
)
{
if
(
use_gpu_
)
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
op_handle_
.
reset
(
new
FusedBroadcastOpHandle
(
op_handle_
=
new
FusedBroadcastOpHandle
(
n
.
get
(),
local_scopes_
,
place_list_
,
nccl_ctxs_
.
get
()
));
n
odes_
.
back
().
get
(),
local_scopes_
,
place_list_
,
nccl_ctxs_
.
get
(
));
#else
#else
PADDLE_THROW
(
"CUDA is not supported."
);
PADDLE_THROW
(
"CUDA is not supported."
);
#endif
#endif
}
else
{
}
else
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
op_handle_
.
reset
(
new
FusedBroadcastOpHandle
(
op_handle_
=
new
FusedBroadcastOpHandle
(
n
.
get
(),
local_scopes_
,
place_list_
,
nccl_ctxs_
.
get
()
));
n
odes_
.
back
().
get
(),
local_scopes_
,
place_list_
,
nccl_ctxs_
.
get
(
));
#else
#else
op_handle_
.
reset
(
op_handle_
=
new
FusedBroadcastOpHandle
(
nodes_
.
back
().
get
(),
new
FusedBroadcastOpHandle
(
n
.
get
(),
local_scopes_
,
place_list_
)
);
local_scopes_
,
place_list_
);
#endif
#endif
}
}
for
(
size_t
i
=
0
;
i
<
input_scope_idxes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
input_scope_idxes
.
size
();
++
i
)
{
// add input var handle
// add input var handle
std
::
unique_ptr
<
ir
::
Node
>
in_node
=
nodes_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"in_node"
+
i
,
ir
::
Node
::
Type
::
kVariable
);
ir
::
CreateNodeForTest
(
"in_node"
+
i
,
ir
::
Node
::
Type
::
kVariable
)
)
;
VarHandle
*
in_var_handle
=
VarHandle
*
in_var_handle
=
new
VarHandle
(
in_node
.
get
(),
1
,
input_scope_idxes
[
i
],
"in_var"
+
i
,
new
VarHandle
(
nodes_
.
back
().
get
(),
1
,
input_scope_idxes
[
i
]
,
place_list_
[
input_scope_idxes
[
i
]]);
"in_var"
+
i
,
place_list_
[
input_scope_idxes
[
i
]]);
vars_
.
emplace_back
(
in_var_handle
);
vars_
.
emplace_back
(
in_var_handle
);
op_handle_
->
AddInput
(
in_var_handle
);
op_handle_
->
AddInput
(
in_var_handle
);
// add output var handle
// add output var handle
for
(
size_t
j
=
0
;
j
<
place_list_
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
place_list_
.
size
();
++
j
)
{
std
::
unique_ptr
<
ir
::
Node
>
out_node
=
nodes_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"out_node"
+
i
,
ir
::
Node
::
Type
::
kVariable
);
ir
::
CreateNodeForTest
(
"out_node"
+
i
,
ir
::
Node
::
Type
::
kVariable
)
)
;
VarHandle
*
out_var_handle
=
VarHandle
*
out_var_handle
=
new
VarHandle
(
n
ew
VarHandle
(
out_node
.
get
(),
2
,
j
,
"out_var"
+
i
,
place_list_
[
j
]);
n
odes_
.
back
()
.
get
(),
2
,
j
,
"out_var"
+
i
,
place_list_
[
j
]);
vars_
.
emplace_back
(
out_var_handle
);
vars_
.
emplace_back
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
}
}
...
...
paddle/fluid/framework/details/gather_op_handle_test.cc
浏览文件 @
3db9fad7
...
@@ -31,9 +31,10 @@ struct TestGatherOpHandle {
...
@@ -31,9 +31,10 @@ struct TestGatherOpHandle {
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
param_scopes_
;
std
::
vector
<
Scope
*>
param_scopes_
;
Scope
g_scope_
;
Scope
g_scope_
;
std
::
unique_ptr
<
OpHandleBase
>
op_handle_
;
OpHandleBase
*
op_handle_
;
std
::
vector
<
std
::
unique_ptr
<
VarHandleBase
>
>
vars_
;
std
::
vector
<
VarHandleBase
*
>
vars_
;
std
::
vector
<
p
::
Place
>
gpu_list_
;
std
::
vector
<
p
::
Place
>
gpu_list_
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes_
;
void
WaitAll
()
{
void
WaitAll
()
{
for
(
size_t
j
=
0
;
j
<
ctxs_
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
ctxs_
.
size
();
++
j
)
{
...
@@ -70,7 +71,7 @@ struct TestGatherOpHandle {
...
@@ -70,7 +71,7 @@ struct TestGatherOpHandle {
}
}
void
InitGatherOp
(
size_t
input_scope_idx
)
{
void
InitGatherOp
(
size_t
input_scope_idx
)
{
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
nodes_
.
clear
()
;
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
local_scopes_
.
push_back
(
&
(
g_scope_
.
NewScope
()));
local_scopes_
.
push_back
(
&
(
g_scope_
.
NewScope
()));
Scope
&
local_scope
=
local_scopes_
.
back
()
->
NewScope
();
Scope
&
local_scope
=
local_scopes_
.
back
()
->
NewScope
();
...
@@ -82,44 +83,45 @@ struct TestGatherOpHandle {
...
@@ -82,44 +83,45 @@ struct TestGatherOpHandle {
}
}
param_scopes_
[
input_scope_idx
]
->
Var
(
"out"
);
param_scopes_
[
input_scope_idx
]
->
Var
(
"out"
);
nodes
.
emplace_back
(
nodes
_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node"
,
ir
::
Node
::
Type
::
kOperation
).
release
());
ir
::
CreateNodeForTest
(
"node"
,
ir
::
Node
::
Type
::
kOperation
).
release
());
op_handle_
.
reset
(
op_handle_
=
new
GatherOpHandle
(
nodes
.
back
().
get
(),
local_scopes_
,
gpu_list_
)
);
new
GatherOpHandle
(
nodes
_
.
back
().
get
(),
local_scopes_
,
gpu_list_
);
// add input
// add input
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
op_handle_
->
SetDeviceContext
(
gpu_list_
[
j
],
ctxs_
[
j
].
get
());
op_handle_
->
SetDeviceContext
(
gpu_list_
[
j
],
ctxs_
[
j
].
get
());
nodes
.
emplace_back
(
nodes
_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node1"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
ir
::
CreateNodeForTest
(
"node1"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
auto
*
in_var_handle
=
auto
*
in_var_handle
=
new
VarHandle
(
nodes
.
back
().
get
(),
1
,
j
,
"input"
,
gpu_list_
[
j
]);
new
VarHandle
(
nodes
_
.
back
().
get
(),
1
,
j
,
"input"
,
gpu_list_
[
j
]);
vars_
.
emplace_back
(
in_var_handle
);
vars_
.
emplace_back
(
in_var_handle
);
op_handle_
->
AddInput
(
in_var_handle
);
op_handle_
->
AddInput
(
in_var_handle
);
}
}
// add dummy var
// add dummy var
nodes
.
emplace_back
(
nodes
_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node2"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
ir
::
CreateNodeForTest
(
"node2"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes
.
back
().
get
()));
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes
_
.
back
().
get
()));
DummyVarHandle
*
in_dummy_var_handle
=
DummyVarHandle
*
in_dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
()
.
get
()
);
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
());
in_dummy_var_handle
->
ClearGeneratedOp
();
in_dummy_var_handle
->
ClearGeneratedOp
();
op_handle_
->
AddInput
(
in_dummy_var_handle
);
op_handle_
->
AddInput
(
in_dummy_var_handle
);
// add output
// add output
nodes
.
emplace_back
(
nodes
_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node3"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
ir
::
CreateNodeForTest
(
"node3"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
auto
*
out_var_handle
=
new
VarHandle
(
nodes
.
back
().
get
(),
2
,
input_scope_idx
,
auto
*
out_var_handle
=
"out"
,
gpu_list_
[
input_scope_idx
]);
new
VarHandle
(
nodes_
.
back
().
get
(),
2
,
input_scope_idx
,
"out"
,
gpu_list_
[
input_scope_idx
]);
vars_
.
emplace_back
(
out_var_handle
);
vars_
.
emplace_back
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
// add dummy var
// add dummy var
nodes
.
emplace_back
(
nodes
_
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node4"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
ir
::
CreateNodeForTest
(
"node4"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes
.
back
().
get
()));
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes
_
.
back
().
get
()));
DummyVarHandle
*
dummy_var_handle
=
DummyVarHandle
*
dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
()
.
get
()
);
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
());
op_handle_
->
AddOutput
(
dummy_var_handle
);
op_handle_
->
AddOutput
(
dummy_var_handle
);
}
}
...
...
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc
浏览文件 @
3db9fad7
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
...
@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
std
::
unique_ptr
<
ir
::
Graph
>
ModifyOpLockAndRecordEventPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ModifyOpLockAndRecordEventPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ir_graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
ir_graph
)
const
{
auto
&
all_ops
=
ir_graph
->
Get
<
GraphOps
>
(
kGraphOps
);
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
ir_graph
);
OpGraphView
graph_view
(
all_ops
);
OpGraphView
graph_view
(
all_ops
);
for
(
auto
&
op
:
all_ops
)
{
for
(
auto
&
op
:
all_ops
)
{
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
()
);
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
==
nullptr
)
continue
;
if
(
compute_op
==
nullptr
)
continue
;
bool
is_lock_and_record_event_free
=
bool
is_lock_and_record_event_free
=
IsLockAndRecordEventFreeComputationOpHandle
(
compute_op
,
graph_view
);
IsLockAndRecordEventFreeComputationOpHandle
(
compute_op
,
graph_view
);
...
...
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
浏览文件 @
3db9fad7
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
...
@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
insert_pending_var
(
version_pair
.
get
()
);
insert_pending_var
(
version_pair
);
}
}
}
}
}
}
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
insert_pending_var
(
var
.
get
()
);
insert_pending_var
(
var
);
}
}
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
for
(
OpHandleBase
*
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
))
{
if
(
op
->
Inputs
().
empty
())
{
if
(
op
->
Inputs
().
empty
())
{
ready_ops
.
insert
(
op
.
get
()
);
ready_ops
.
insert
(
op
);
}
else
{
}
else
{
pending_ops
.
insert
({
op
.
get
(),
op
.
get
()
->
NoDupInputSize
()});
pending_ops
.
insert
({
op
,
op
->
NoDupInputSize
()});
}
}
}
}
...
@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
...
@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
REGISTER_PASS
(
multi_devices_check_pass
,
REGISTER_PASS
(
multi_devices_check_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
)
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphDepVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphDepVars
);
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphOps
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kShardedVarDevice
);
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
3db9fad7
...
@@ -34,7 +34,14 @@
...
@@ -34,7 +34,14 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
namespace
{
namespace
{
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef
std
::
vector
<
OpHandleBase
*>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
...
@@ -92,7 +99,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
...
@@ -92,7 +99,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
}
}
var_holder
.
emplace_back
(
var
);
var_holder
.
emplace_back
(
var
);
}
else
{
}
else
{
var
=
var_holder
.
rbegin
()
->
get
();
var
=
*
var_holder
.
rbegin
();
}
}
return
var
;
return
var
;
}
}
...
@@ -154,7 +161,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
...
@@ -154,7 +161,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
op_handle
->
SetDeviceContext
(
p
,
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
@@ -303,7 +310,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -303,7 +310,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kShardedVarDevice
,
new
ShardedVarDevice
);
// find send/recv vars so that we can place the distributed training
// find send/recv vars so that we can place the distributed training
// related op in the place 0
// related op in the place 0
...
@@ -317,11 +323,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -317,11 +323,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
bool
is_dist_train
=
false
;
bool
is_dist_train
=
false
;
std
::
unordered_map
<
std
::
string
,
int
>
sharded_var_device
;
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
if
(
boost
::
get
<
int
>
(
if
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
int
op_dev_id
=
CreateRPCOp
(
&
result
,
node
);
int
op_dev_id
=
CreateRPCOp
(
&
result
,
node
,
&
sharded_var_device
);
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"Can not schedule the RPC operator to the right place."
);
"Can not schedule the RPC operator to the right place."
);
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
...
@@ -337,7 +345,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -337,7 +345,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
}
else
if
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
}
else
if
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kDist
))
{
static_cast
<
int
>
(
OpRole
::
kDist
))
{
int
op_dev_id
=
CreateDistTrainOp
(
&
result
,
node
);
int
op_dev_id
=
CreateDistTrainOp
(
&
result
,
node
,
&
sharded_var_device
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
bcast_var_name_set
[
op_dev_id
].
emplace
(
origin_param_name
);
bcast_var_name_set
[
op_dev_id
].
emplace
(
origin_param_name
);
...
@@ -356,12 +364,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -356,12 +364,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
// the block.
// the block.
is_forwarding
=
false
;
is_forwarding
=
false
;
}
else
{
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
result
,
node
);
int
op_dev_id
=
GetOpDeviceID
(
result
,
node
,
sharded_var_device
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
sharded_var_device
.
emplace
(
n
->
Name
(),
op_dev_id
);
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
}
}
else
{
}
else
{
// This op runs on all devices, and its output may have parameter's
// This op runs on all devices, and its output may have parameter's
...
@@ -398,8 +405,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -398,8 +405,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
sharded_var_device
.
emplace
(
g_name
,
cur_device_id
);
.
emplace
(
g_name
,
cur_device_id
);
if
(
!
is_dist_train
)
{
if
(
!
is_dist_train
)
{
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
}
}
...
@@ -458,7 +464,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -458,7 +464,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
* Only variables should be the leaves of graph.
* Only variables should be the leaves of graph.
*/
*/
AddOutputToLeafOps
(
&
result
);
AddOutputToLeafOps
(
&
result
);
PADDLE_ENFORCE
(
!
ir
::
HasCircle
(
result
)
);
result
.
Erase
<
GraphOps
>
(
kGraphOps
);
return
graph
;
return
graph
;
}
}
...
@@ -498,7 +504,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
...
@@ -498,7 +504,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
auto
*
in
=
auto
*
in
=
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
src_dev_id
).
at
(
p_name
).
back
()
.
get
()
;
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
src_dev_id
).
at
(
p_name
).
back
();
op_handle
->
AddInput
(
in
);
op_handle
->
AddInput
(
in
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
...
@@ -535,7 +541,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
...
@@ -535,7 +541,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_varnames
.
size
();
++
dev_id
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_varnames
.
size
();
++
dev_id
)
{
for
(
auto
&
p_name
:
bcast_varnames
[
dev_id
])
{
for
(
auto
&
p_name
:
bcast_varnames
[
dev_id
])
{
auto
*
in
=
auto
*
in
=
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
dev_id
).
at
(
p_name
).
back
()
.
get
()
;
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
dev_id
).
at
(
p_name
).
back
();
op_handle
->
AddInput
(
in
);
op_handle
->
AddInput
(
in
);
for
(
size_t
out_dev_id
=
0
;
out_dev_id
<
places_
.
size
();
++
out_dev_id
)
{
for
(
size_t
out_dev_id
=
0
;
out_dev_id
<
places_
.
size
();
++
out_dev_id
)
{
auto
&
p
=
places_
[
out_dev_id
];
auto
&
p
=
places_
[
out_dev_id
];
...
@@ -571,7 +577,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
...
@@ -571,7 +577,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
...
@@ -579,7 +585,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
...
@@ -579,7 +585,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
()
);
op_handle
->
AddInput
(
prev_grad
);
auto
var
=
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
...
@@ -600,14 +606,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
...
@@ -600,14 +606,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
d_name
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
()
.
get
()
);
op_handle
->
AddInput
(
vars
.
back
());
auto
var
=
new
VarHandle
(
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
d_name
,
ir
::
Node
::
Type
::
kVariable
),
result
->
CreateEmptyNode
(
d_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
d_name
,
p
);
vars
.
size
(),
i
,
d_name
,
p
);
...
@@ -617,8 +623,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
...
@@ -617,8 +623,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
}
}
}
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
)
const
{
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
return
-
1
;
}
}
...
@@ -631,15 +638,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
...
@@ -631,15 +638,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
graph
,
param_grad
[
1
]);
int
dev_id
=
GetVarDeviceID
(
graph
,
param_grad
[
1
]
,
sharded_var_device
);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
return
dev_id
;
return
dev_id
;
}
}
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
{
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
,
auto
&
sharded_var_device
=
graph
.
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
);
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
{
auto
got
=
sharded_var_device
.
find
(
varname
);
auto
got
=
sharded_var_device
.
find
(
varname
);
return
got
==
sharded_var_device
.
end
()
?
-
1
:
got
->
second
;
return
got
==
sharded_var_device
.
end
()
?
-
1
:
got
->
second
;
}
}
...
@@ -690,7 +697,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -690,7 +697,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
...
@@ -698,7 +705,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -698,7 +705,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
()
);
op_handle
->
AddInput
(
prev_grad
);
}
}
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
dst_dev_id
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
dst_dev_id
][
og
];
auto
var
=
auto
var
=
...
@@ -709,8 +716,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -709,8 +716,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return
var
;
return
var
;
}
}
int
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Graph
*
result
,
int
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Node
*
node
)
const
{
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
output_var_names
;
std
::
vector
<
std
::
string
>
output_var_names
;
...
@@ -725,23 +733,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -725,23 +733,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node
->
Op
()
->
Type
()
==
"split_selected_rows"
||
node
->
Op
()
->
Type
()
==
"split_selected_rows"
||
node
->
Op
()
->
Type
()
==
"split_ids"
)
{
node
->
Op
()
->
Type
()
==
"split_ids"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
],
*
sharded_var_device
);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
for
(
auto
&
varname
:
input_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
.
emplace
(
varname
,
op_dev_id
);
}
}
}
}
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
],
*
sharded_var_device
);
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
{
}
else
{
LOG
(
ERROR
)
<<
"got unexpected dist op: "
<<
node
->
Op
()
->
Type
();
LOG
(
ERROR
)
<<
"got unexpected dist op: "
<<
node
->
Op
()
->
Type
();
...
@@ -759,14 +766,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -759,14 +766,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
}
}
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
ir
::
Node
*
input
:
node
->
inputs
)
{
for
(
ir
::
Node
*
input
:
node
->
inputs
)
{
VarHandle
*
var
=
nullptr
;
VarHandle
*
var
=
nullptr
;
for
(
int
place_offset
=
0
;
place_offset
<
num_places
;
++
place_offset
)
{
for
(
int
place_offset
=
0
;
place_offset
<
num_places
;
++
place_offset
)
{
auto
&
var_holders
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holders
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
input
->
Name
()];
auto
&
var_holder
=
var_holders
[
input
->
Name
()];
if
(
!
var_holder
.
empty
())
{
if
(
!
var_holder
.
empty
())
{
var
=
var_holder
.
rbegin
()
->
get
();
var
=
*
var_holder
.
rbegin
();
op_handle
->
AddInput
(
var
);
op_handle
->
AddInput
(
var
);
}
}
}
}
...
@@ -774,12 +781,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
...
@@ -774,12 +781,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
}
}
// Create RPC related op handles that connects its in ops and out ops.
// Create RPC related op handles that connects its in ops and out ops.
int
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
int
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Node
*
node
)
const
{
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
*
result
,
node
->
inputs
[
0
]
->
Name
());
op_dev_id
=
GetVarDeviceID
(
*
result
,
node
->
inputs
[
0
]
->
Name
(),
*
sharded_var_device
);
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
"This hack no longer holds, please fix."
);
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
// the variable name which contains .block means it was splited by
...
@@ -797,11 +806,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -797,11 +806,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
VLOG
(
100
)
<<
"send grad "
<<
input_var_names
[
0
]
<<
" origin "
VLOG
(
100
)
<<
"send grad "
<<
input_var_names
[
0
]
<<
" origin "
<<
send_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
<<
send_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
for
(
auto
&
varname
:
input_var_names
)
{
for
(
auto
&
varname
:
input_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
.
emplace
(
varname
,
op_dev_id
);
}
}
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
sharded_var_device
->
emplace
(
send_param_grad
[
1
],
op_dev_id
);
.
emplace
(
send_param_grad
[
1
],
op_dev_id
);
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
std
::
vector
<
std
::
string
>
output_var_names
;
std
::
vector
<
std
::
string
>
output_var_names
;
...
@@ -811,7 +818,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -811,7 +818,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
auto
recv_param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
auto
recv_param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
if
(
recv_param_grad
.
size
()
==
2U
)
{
if
(
recv_param_grad
.
size
()
==
2U
)
{
op_dev_id
=
GetVarDeviceID
(
*
result
,
recv_param_grad
[
1
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
recv_param_grad
[
1
],
*
sharded_var_device
);
VLOG
(
100
)
<<
"recv param "
<<
recv_param_grad
[
0
]
VLOG
(
100
)
<<
"recv param "
<<
recv_param_grad
[
0
]
<<
" get grad place: "
<<
recv_param_grad
[
1
]
<<
" get grad place: "
<<
recv_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
<<
" place: "
<<
op_dev_id
;
...
@@ -819,8 +827,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -819,8 +827,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
}
}
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
{
}
else
{
// send_barrier, fetch_barrier will run on place 0;
// send_barrier, fetch_barrier will run on place 0;
...
@@ -839,7 +846,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -839,7 +846,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
// send_barrier, recv, fetch_barrier's inputs are deps var, get them from
// send_barrier, recv, fetch_barrier's inputs are deps var, get them from
// all places
// all places
auto
p
=
places_
[
op_dev_id
];
auto
p
=
places_
[
op_dev_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
op_handle
->
SetDeviceContext
(
p
,
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
@@ -847,7 +854,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -847,7 +854,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
int
outvar_dev_id
=
op_dev_id
;
int
outvar_dev_id
=
op_dev_id
;
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
outvar_dev_id
=
GetVarDeviceID
(
*
result
,
output
->
Name
());
outvar_dev_id
=
GetVarDeviceID
(
*
result
,
output
->
Name
(),
*
sharded_var_device
);
PADDLE_ENFORCE_NE
(
outvar_dev_id
,
-
1
);
PADDLE_ENFORCE_NE
(
outvar_dev_id
,
-
1
);
}
}
p
=
places_
[
outvar_dev_id
];
p
=
places_
[
outvar_dev_id
];
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.h
浏览文件 @
3db9fad7
...
@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
#endif
int
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
)
const
;
int
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
int
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
int
CreateRPCOp
(
int
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
;
int
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
const
std
::
vector
<
ir
::
Node
*>
&
nodes
)
const
;
const
std
::
vector
<
ir
::
Node
*>
&
nodes
)
const
;
...
@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
dev_id
)
const
;
int
dev_id
)
const
;
int
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
;
int
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
...
...
paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
浏览文件 @
3db9fad7
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
...
@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
});
size_t
op_id
=
0
;
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
kGraphOps
))
{
for
(
auto
&
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
graph
))
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
<<
std
::
endl
;
...
...
paddle/fluid/framework/details/multi_devices_helper.h
浏览文件 @
3db9fad7
...
@@ -35,23 +35,14 @@ namespace details {
...
@@ -35,23 +35,14 @@ namespace details {
// The outside vector is the device vector. Each element of this vector is a
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the
// will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
// `std::vector<VarHandle*>` is the version of varaibles.
typedef
std
::
vector
<
typedef
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandle
*>>>
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
GraphVars
;
GraphVars
;
const
char
kGraphVars
[]
=
"vars"
;
const
char
kGraphVars
[]
=
"vars"
;
// aux variables to represent dependency. Useful to resolve data hazard.
// aux variables to represent dependency. Useful to resolve data hazard.
typedef
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>
>
GraphDepVars
;
typedef
std
::
unordered_set
<
VarHandleBase
*
>
GraphDepVars
;
const
char
kGraphDepVars
[]
=
"dep_vars"
;
const
char
kGraphDepVars
[]
=
"dep_vars"
;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
typedef
std
::
unordered_map
<
std
::
string
,
int
>
ShardedVarDevice
;
const
char
kShardedVarDevice
[]
=
"sharded_var_device"
;
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/op_graph_view.cc
浏览文件 @
3db9fad7
...
@@ -20,19 +20,16 @@ namespace paddle {
...
@@ -20,19 +20,16 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
OpGraphView
::
OpGraphView
(
OpGraphView
::
OpGraphView
(
const
std
::
vector
<
OpHandleBase
*>
&
ops
)
{
Build
(
ops
);
}
const
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
&
ops
)
{
Build
(
ops
);
}
void
OpGraphView
::
Build
(
const
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>
>
&
ops
)
{
void
OpGraphView
::
Build
(
const
std
::
vector
<
OpHandleBase
*
>
&
ops
)
{
for
(
auto
&
op
:
ops
)
{
for
(
auto
&
op
:
ops
)
{
preceding_ops_
[
op
.
get
()
];
preceding_ops_
[
op
];
pending_ops_
[
op
.
get
()
];
pending_ops_
[
op
];
for
(
auto
&
var
:
op
->
Outputs
())
{
for
(
auto
&
var
:
op
->
Outputs
())
{
for
(
auto
&
pending_op
:
var
->
PendingOps
())
{
for
(
auto
&
pending_op
:
var
->
PendingOps
())
{
preceding_ops_
[
pending_op
].
insert
(
op
.
get
()
);
preceding_ops_
[
pending_op
].
insert
(
op
);
pending_ops_
[
op
.
get
()
].
insert
(
pending_op
);
pending_ops_
[
op
].
insert
(
pending_op
);
}
}
}
}
}
}
...
@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
...
@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
"There are duplicate ops in graph."
);
"There are duplicate ops in graph."
);
}
}
size_t
OpGraphView
::
OpNumber
()
const
{
return
preceding_ops_
.
size
();
}
std
::
unordered_set
<
OpHandleBase
*>
OpGraphView
::
AllOps
()
const
{
std
::
unordered_set
<
OpHandleBase
*>
OpGraphView
::
AllOps
()
const
{
std
::
unordered_set
<
OpHandleBase
*>
ret
;
std
::
unordered_set
<
OpHandleBase
*>
ret
;
for
(
auto
&
pair
:
preceding_ops_
)
{
for
(
auto
&
pair
:
preceding_ops_
)
{
...
@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
...
@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
op
==
nullptr
?
"nullptr"
:
op
->
DebugString
());
op
==
nullptr
?
"nullptr"
:
op
->
DebugString
());
}
}
const
std
::
unordered_set
<
OpHandleBase
*>
&
OpGraphView
::
PrecedingOps
(
OpHandleBase
*
op
)
const
{
EnforceHasOp
(
op
);
return
preceding_ops_
.
at
(
op
);
}
const
std
::
unordered_set
<
OpHandleBase
*>
&
OpGraphView
::
PendingOps
(
const
std
::
unordered_set
<
OpHandleBase
*>
&
OpGraphView
::
PendingOps
(
OpHandleBase
*
op
)
const
{
OpHandleBase
*
op
)
const
{
EnforceHasOp
(
op
);
EnforceHasOp
(
op
);
...
...
paddle/fluid/framework/details/op_graph_view.h
浏览文件 @
3db9fad7
...
@@ -26,21 +26,16 @@ namespace details {
...
@@ -26,21 +26,16 @@ namespace details {
class
OpGraphView
{
class
OpGraphView
{
public:
public:
explicit
OpGraphView
(
const
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
&
ops
);
explicit
OpGraphView
(
const
std
::
vector
<
OpHandleBase
*>
&
ops
);
size_t
OpNumber
()
const
;
std
::
unordered_set
<
OpHandleBase
*>
AllOps
()
const
;
std
::
unordered_set
<
OpHandleBase
*>
AllOps
()
const
;
const
std
::
unordered_set
<
OpHandleBase
*>
&
PrecedingOps
(
OpHandleBase
*
op
)
const
;
const
std
::
unordered_set
<
OpHandleBase
*>
&
PendingOps
(
OpHandleBase
*
op
)
const
;
const
std
::
unordered_set
<
OpHandleBase
*>
&
PendingOps
(
OpHandleBase
*
op
)
const
;
bool
HasOp
(
OpHandleBase
*
op
)
const
;
bool
HasOp
(
OpHandleBase
*
op
)
const
;
private:
private:
void
Build
(
const
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>
>
&
ops
);
void
Build
(
const
std
::
vector
<
OpHandleBase
*
>
&
ops
);
void
EnforceHasOp
(
OpHandleBase
*
op
)
const
;
void
EnforceHasOp
(
OpHandleBase
*
op
)
const
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unordered_set
<
OpHandleBase
*>>
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unordered_set
<
OpHandleBase
*>>
...
...
paddle/fluid/framework/details/op_handle_base.h
浏览文件 @
3db9fad7
...
@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
...
@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// It's responsible for populating necessary fields of ir::Node.
// It's responsible for populating necessary fields of ir::Node.
class
OpHandleBase
{
class
OpHandleBase
{
public:
public:
explicit
OpHandleBase
(
ir
::
Node
*
node
)
:
node_
(
node
)
{}
// Owned by `node`. No need to be deleted explicitly.
explicit
OpHandleBase
(
ir
::
Node
*
node
)
:
node_
(
node
)
{
node_
->
WrappedBy
(
this
);
}
virtual
~
OpHandleBase
();
virtual
~
OpHandleBase
();
...
...
paddle/fluid/framework/details/reduce_op_handle_test.cc
浏览文件 @
3db9fad7
...
@@ -30,8 +30,8 @@ struct TestReduceOpHandle {
...
@@ -30,8 +30,8 @@ struct TestReduceOpHandle {
Scope
g_scope_
;
Scope
g_scope_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
param_scopes_
;
std
::
vector
<
Scope
*>
param_scopes_
;
std
::
unique_ptr
<
OpHandleBase
>
op_handle_
;
OpHandleBase
*
op_handle_
;
std
::
vector
<
std
::
unique_ptr
<
VarHandleBase
>
>
vars_
;
std
::
vector
<
VarHandleBase
*
>
vars_
;
std
::
vector
<
p
::
Place
>
gpu_list_
;
std
::
vector
<
p
::
Place
>
gpu_list_
;
std
::
vector
<
std
::
unique_ptr
<
p
::
DeviceContext
>>
ctxs_
;
std
::
vector
<
std
::
unique_ptr
<
p
::
DeviceContext
>>
ctxs_
;
...
...
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
3db9fad7
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -71,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -71,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
// in computation ops
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unique_ptr
<
ReferenceCountOpHandle
>
>
std
::
unordered_map
<
OpHandleBase
*
,
ReferenceCountOpHandle
*
>
compute_ref_cnt_map
;
compute_ref_cnt_map
;
auto
get_ref_cnts_from_compute_op
=
[
&
](
auto
get_ref_cnts_from_compute_op
=
[
&
](
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
OpHandleBase
*
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
std
::
vector
<
std
::
string
>
var_names_in_op
;
std
::
vector
<
std
::
string
>
var_names_in_op
;
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
()
);
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
==
nullptr
||
if
(
compute_op
==
nullptr
||
!
platform
::
is_gpu_place
(
compute_op
->
GetPlace
()))
!
platform
::
is_gpu_place
(
compute_op
->
GetPlace
()))
return
var_names_in_op
;
return
var_names_in_op
;
...
@@ -121,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -121,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
};
};
auto
update_ref_cnts_from_non_compute_op
=
[
&
](
auto
update_ref_cnts_from_non_compute_op
=
[
&
](
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
OpHandleBase
*
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
if
(
dynamic_cast
<
ComputationOpHandle
*>
(
op
)
!=
nullptr
)
return
;
if
(
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
())
!=
nullptr
)
return
;
for
(
VarHandleBase
*
var_handle_base
:
vars
)
{
for
(
VarHandleBase
*
var_handle_base
:
vars
)
{
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var_handle_base
);
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var_handle_base
);
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
())
continue
;
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
())
continue
;
...
@@ -151,21 +150,21 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -151,21 +150,21 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node
,
next_compute_op
->
GetScope
(),
place
,
{
var_name
},
ref_cnt_node
,
next_compute_op
->
GetScope
(),
place
,
{
var_name
},
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
AddDependencyBetween
(
next_compute_op
,
ref_cnt_handle
,
graph
.
get
());
AddDependencyBetween
(
next_compute_op
,
ref_cnt_handle
,
graph
.
get
());
compute_ref_cnt_map
[
next_compute_op
]
.
reset
(
ref_cnt_handle
)
;
compute_ref_cnt_map
[
next_compute_op
]
=
ref_cnt_handle
;
}
}
}
}
}
}
}
}
};
};
auto
&
all_ops
=
graph
->
Get
<
GraphOps
>
(
kGraphOps
);
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
);
for
(
auto
&
op
:
all_ops
)
{
for
(
auto
&
op
:
all_ops
)
{
auto
in_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Inputs
());
auto
in_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Inputs
());
auto
out_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Outputs
());
auto
out_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Outputs
());
if
(
in_var_names
.
empty
()
&&
out_var_names
.
empty
())
continue
;
if
(
in_var_names
.
empty
()
&&
out_var_names
.
empty
())
continue
;
in_var_names
.
insert
(
in_var_names
.
end
(),
out_var_names
.
begin
(),
in_var_names
.
insert
(
in_var_names
.
end
(),
out_var_names
.
begin
(),
out_var_names
.
end
());
out_var_names
.
end
());
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
()
);
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
compute_op
->
GetPlace
());
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
compute_op
->
GetPlace
());
ir
::
Node
*
ref_cnt_node
=
ir
::
Node
*
ref_cnt_node
=
graph
->
CreateEmptyNode
(
"reference_count"
,
ir
::
Node
::
Type
::
kOperation
);
graph
->
CreateEmptyNode
(
"reference_count"
,
ir
::
Node
::
Type
::
kOperation
);
...
@@ -173,7 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -173,7 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
AddDependencyBetween
(
compute_op
,
ref_cnt_handle
,
graph
.
get
());
AddDependencyBetween
(
compute_op
,
ref_cnt_handle
,
graph
.
get
());
compute_ref_cnt_map
[
compute_op
]
.
reset
(
ref_cnt_handle
)
;
compute_ref_cnt_map
[
compute_op
]
=
ref_cnt_handle
;
}
}
for
(
auto
&
op
:
all_ops
)
{
for
(
auto
&
op
:
all_ops
)
{
...
@@ -181,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -181,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
update_ref_cnts_from_non_compute_op
(
op
,
op
->
Outputs
());
update_ref_cnts_from_non_compute_op
(
op
,
op
->
Outputs
());
}
}
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>
>
new_all_ops
;
std
::
vector
<
OpHandleBase
*
>
new_all_ops
;
new_all_ops
.
reserve
(
compute_ref_cnt_map
.
size
()
+
all_ops
.
size
());
new_all_ops
.
reserve
(
compute_ref_cnt_map
.
size
()
+
all_ops
.
size
());
for
(
auto
&
op
:
all_ops
)
{
for
(
auto
&
op
:
all_ops
)
{
new_all_ops
.
emplace_back
(
std
::
move
(
op
));
new_all_ops
.
emplace_back
(
std
::
move
(
op
));
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
()
.
get
()
);
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
());
if
(
it
!=
compute_ref_cnt_map
.
end
())
{
if
(
it
!=
compute_ref_cnt_map
.
end
())
{
// Add LeafNode to ReferenceCountOpHandle
// Add LeafNode to ReferenceCountOpHandle
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
...
...
paddle/fluid/framework/details/ssa_graph_executor.cc
浏览文件 @
3db9fad7
...
@@ -19,14 +19,16 @@ namespace framework {
...
@@ -19,14 +19,16 @@ namespace framework {
namespace
details
{
namespace
details
{
SSAGraphExecutor
::~
SSAGraphExecutor
()
{}
SSAGraphExecutor
::~
SSAGraphExecutor
()
{}
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
FetchOpHandle
*>*
fetch_ops
)
{
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>*
fetch_ops
)
{
if
(
fetch_ops
->
empty
())
return
;
if
(
fetch_ops
->
empty
())
return
;
for
(
auto
&
op
:
*
fetch_ops
)
{
for
(
auto
&
op
:
*
fetch_ops
)
{
for
(
auto
&
out_var
:
op
->
Node
()
->
outputs
)
{
for
(
auto
&
out_var
:
op
->
Node
()
->
outputs
)
{
graph
->
RemoveNode
(
out_var
);
graph
->
RemoveNode
(
out_var
);
}
}
for
(
auto
&
in_var
:
op
->
Inputs
())
{
in_var
->
RemoveOutput
(
op
,
op
->
Node
());
}
graph
->
RemoveNode
(
op
->
Node
());
graph
->
RemoveNode
(
op
->
Node
());
}
}
fetch_ops
->
clear
();
fetch_ops
->
clear
();
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
3db9fad7
...
@@ -38,8 +38,7 @@ class SSAGraphExecutor {
...
@@ -38,8 +38,7 @@ class SSAGraphExecutor {
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
=
0
;
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
=
0
;
};
};
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
FetchOpHandle
*>*
fetch_ops
);
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>*
fetch_ops
);
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
3db9fad7
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
version_pair
.
get
()
);
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
version_pair
);
}
}
}
}
}
}
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
var
.
get
()
);
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
var
);
}
}
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
details
::
kGraphOps
))
{
for
(
auto
&
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph_
))
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
.
get
()
);
ready_ops
.
insert
(
op
);
}
else
{
}
else
{
InsertPendingOp
(
&
pending_ops
,
op
.
get
()
);
InsertPendingOp
(
&
pending_ops
,
op
);
}
}
}
}
// Step 2. Insert FetchOps
// Step 2. Insert FetchOps
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>
>
fetch_ops
;
std
::
vector
<
FetchOpHandle
*
>
fetch_ops
;
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>
>
fetch_dependencies
;
std
::
unordered_set
<
VarHandleBase
*
>
fetch_dependencies
;
FeedFetchList
fetch_data
(
fetch_tensors
.
size
());
FeedFetchList
fetch_data
(
fetch_tensors
.
size
());
InsertFetchOps
(
fetch_tensors
,
&
fetch_ops
,
&
fetch_dependencies
,
&
pending_ops
,
InsertFetchOps
(
fetch_tensors
,
&
fetch_ops
,
&
fetch_dependencies
,
&
pending_ops
,
...
@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for
(
auto
&
run_op_future
:
run_op_futures_
)
{
for
(
auto
&
run_op_future
:
run_op_futures_
)
{
run_op_future
.
wait
();
run_op_future
.
wait
();
}
}
ClearFetchOp
(
graph_
.
get
(),
&
fetch_ops
);
exception_holder_
.
ReThrow
();
exception_holder_
.
ReThrow
();
}
else
{
}
else
{
continue
;
continue
;
...
@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void
ThreadedSSAGraphExecutor
::
InsertFetchOps
(
void
ThreadedSSAGraphExecutor
::
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>
>
*
fetch_ops
,
std
::
vector
<
FetchOpHandle
*
>
*
fetch_ops
,
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>
>
*
fetch_dependencies
,
std
::
unordered_set
<
VarHandleBase
*
>
*
fetch_dependencies
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
)
{
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
)
{
...
@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
...
@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
fetched_vars
[
fetch_var_name
].
push_back
(
*
it
->
second
.
rbegin
());
}
}
}
}
}
}
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
3db9fad7
...
@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
VarHandleBase
*
var
)
const
;
VarHandleBase
*
var
)
const
;
void
InsertFetchOps
(
void
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
vector
<
std
::
string
>
&
fetch_tensor
s
,
std
::
vector
<
FetchOpHandle
*>
*
fetch_op
s
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>
*
fetch_op
s
,
std
::
unordered_set
<
VarHandleBase
*>
*
fetch_dependencie
s
,
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
*
fetch_dependencie
s
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_op
s
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_op
s
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_var
s
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending
_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready
_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
);
FeedFetchList
*
fetch_data
);
private:
private:
ExecutionStrategy
strategy_
;
ExecutionStrategy
strategy_
;
...
...
paddle/fluid/framework/details/var_handle.cc
浏览文件 @
3db9fad7
...
@@ -20,6 +20,8 @@ namespace details {
...
@@ -20,6 +20,8 @@ namespace details {
VarHandleBase
::~
VarHandleBase
()
{}
VarHandleBase
::~
VarHandleBase
()
{}
VarHandle
::~
VarHandle
()
{
VLOG
(
4
)
<<
"deleting var handle "
<<
DebugString
();
}
std
::
string
VarHandle
::
DebugString
()
const
{
std
::
string
VarHandle
::
DebugString
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
name_
<<
":"
<<
place_
;
ss
<<
name_
<<
":"
<<
place_
;
...
@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const {
...
@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const {
}
}
std
::
string
DummyVarHandle
::
DebugString
()
const
{
return
node_
->
Name
();
}
std
::
string
DummyVarHandle
::
DebugString
()
const
{
return
node_
->
Name
();
}
DummyVarHandle
::~
DummyVarHandle
()
{
VLOG
(
4
)
<<
"deleting dummy var handle "
<<
DebugString
();
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/var_handle.h
浏览文件 @
3db9fad7
...
@@ -35,7 +35,10 @@ class OpHandleBase;
...
@@ -35,7 +35,10 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e.
// A variable can only be generated by a single operator. i.e.
// This is a single assignment graph.
// This is a single assignment graph.
struct
VarHandleBase
{
struct
VarHandleBase
{
explicit
VarHandleBase
(
ir
::
Node
*
node
)
:
node_
(
node
)
{}
// Owned by `node`. No need to be deleted explicitly.
explicit
VarHandleBase
(
ir
::
Node
*
node
)
:
node_
(
node
)
{
node_
->
WrappedBy
(
this
);
}
virtual
~
VarHandleBase
();
virtual
~
VarHandleBase
();
...
@@ -94,6 +97,8 @@ struct VarHandleBase {
...
@@ -94,6 +97,8 @@ struct VarHandleBase {
struct
VarHandle
:
public
VarHandleBase
{
struct
VarHandle
:
public
VarHandleBase
{
explicit
VarHandle
(
ir
::
Node
*
node
)
:
VarHandleBase
(
node
)
{}
explicit
VarHandle
(
ir
::
Node
*
node
)
:
VarHandleBase
(
node
)
{}
virtual
~
VarHandle
();
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
VarHandle
(
ir
::
Node
*
node
,
size_t
version
,
size_t
scope_index
,
VarHandle
(
ir
::
Node
*
node
,
size_t
version
,
size_t
scope_index
,
...
@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase {
...
@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase {
struct
DummyVarHandle
:
public
VarHandleBase
{
struct
DummyVarHandle
:
public
VarHandleBase
{
explicit
DummyVarHandle
(
ir
::
Node
*
node
)
:
VarHandleBase
(
node
)
{}
explicit
DummyVarHandle
(
ir
::
Node
*
node
)
:
VarHandleBase
(
node
)
{}
virtual
~
DummyVarHandle
();
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
};
};
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
3db9fad7
...
@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
...
@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library
(
pass_builder SRCS pass_builder.cc DEPS pass
)
cc_library
(
pass_builder SRCS pass_builder.cc DEPS pass
)
cc_test
(
node_test SRCS node_test.cc DEPS node
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
3db9fad7
...
@@ -102,6 +102,15 @@ class Graph {
...
@@ -102,6 +102,15 @@ class Graph {
attr_dels_
[
attr_name
]
=
[]()
{};
attr_dels_
[
attr_name
]
=
[]()
{};
}
}
template
<
typename
AttrType
>
void
Erase
(
const
std
::
string
&
attr_name
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
!=
0
,
"%s not set in the graph"
,
attr_name
);
attr_dels_
[
attr_name
]();
attrs_
.
erase
(
attr_name
);
attr_dels_
.
erase
(
attr_name
);
}
const
std
::
unordered_set
<
ir
::
Node
*>
&
Nodes
()
const
{
return
node_set_
;
}
const
std
::
unordered_set
<
ir
::
Node
*>
&
Nodes
()
const
{
return
node_set_
;
}
// Create a normal variable with non-null VarDesc.
// Create a normal variable with non-null VarDesc.
...
...
paddle/fluid/framework/ir/graph_helper.h
浏览文件 @
3db9fad7
...
@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
...
@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
BuildOperationAdjList
(
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
BuildOperationAdjList
(
const
Graph
&
graph
);
const
Graph
&
graph
);
template
<
typename
T
>
std
::
vector
<
T
*>
FilterByNodeWrapper
(
const
Graph
&
graph
)
{
std
::
vector
<
T
*>
ret
;
for
(
ir
::
Node
*
n
:
graph
.
Nodes
())
{
if
(
n
->
IsWrappedBy
<
T
>
())
ret
.
push_back
(
&
n
->
Wrapper
<
T
>
());
}
return
ret
;
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/node.h
浏览文件 @
3db9fad7
...
@@ -15,7 +15,10 @@ limitations under the License. */
...
@@ -15,7 +15,10 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/macros.h"
...
@@ -24,9 +27,33 @@ namespace paddle {
...
@@ -24,9 +27,33 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
// Node should normally created by Graph::CreateXXXNode().
// Node should only created by Graph::CreateXXXNode().
// 1. Every Node should be part of a graph. No dangling Node exists.
// 2. Node only contains members necessary for building graph structure.
// It doesn't contain other unrelated members, such as device, etc.
//
// Sometimes, for specific usages, Node needs to have additional members,
// such as device_placement, version in order to be executed. It is suggested
// to use composition pattern.
//
// class RunnableOp {
// RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); }
//
// int any_thing_;
// }
//
// RunnableOp is owned by the ir::Node that composes it. In other words.
// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node
// is deleted from the graph.
class
Node
{
class
Node
{
public:
public:
virtual
~
Node
()
{
if
(
!
wrapper_
.
empty
())
{
VLOG
(
4
)
<<
"ir::Node deleting a wrapper node "
<<
Name
();
wrapper_deleter_
();
}
}
enum
class
Type
{
kOperation
,
kVariable
};
enum
class
Type
{
kOperation
,
kVariable
};
#if !defined(_WIN32) // msvc not support constexpr correctly.
#if !defined(_WIN32) // msvc not support constexpr correctly.
static
constexpr
char
kControlDepVarName
[]
=
"__control_var"
;
static
constexpr
char
kControlDepVarName
[]
=
"__control_var"
;
...
@@ -48,6 +75,29 @@ class Node {
...
@@ -48,6 +75,29 @@ class Node {
return
op_desc_
.
get
();
return
op_desc_
.
get
();
}
}
// Set the `wrapper` that wraps the Node. `wrapper` is owned by Node.
template
<
typename
T
>
void
WrappedBy
(
T
*
wrapper
)
{
if
(
!
wrapper_
.
empty
())
{
wrapper_deleter_
();
}
wrapper_
=
wrapper
;
wrapper_deleter_
=
[
wrapper
]()
{
delete
wrapper
;
};
wrapper_type_
=
std
::
type_index
(
typeid
(
T
));
}
// Return a reference to the `wrapper`.
template
<
typename
T
>
T
&
Wrapper
()
{
return
*
boost
::
any_cast
<
T
*>
(
wrapper_
);
}
// Test if the Node is wrapped by type T.
template
<
typename
T
>
bool
IsWrappedBy
()
{
return
std
::
type_index
(
typeid
(
T
))
==
wrapper_type_
;
}
// Please don't use this API!
// Please don't use this API!
int
id
()
const
{
return
id_
;
}
int
id
()
const
{
return
id_
;
}
...
@@ -99,6 +149,11 @@ class Node {
...
@@ -99,6 +149,11 @@ class Node {
static
int
count_
;
static
int
count_
;
// Please don't use this API or make this public.
// Please don't use this API or make this public.
static
void
ResetId
()
{
count_
=
0
;
}
static
void
ResetId
()
{
count_
=
0
;
}
boost
::
any
wrapper_
;
std
::
function
<
void
(
void
)
>
wrapper_deleter_
;
std
::
type_index
wrapper_type_
=
std
::
type_index
(
typeid
(
void
));
DISABLE_COPY_AND_ASSIGN
(
Node
);
DISABLE_COPY_AND_ASSIGN
(
Node
);
};
};
...
...
paddle/fluid/framework/ir/node_test.cc
0 → 100644
浏览文件 @
3db9fad7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
RunnableOp
{
public:
RunnableOp
(
Node
*
node
,
bool
*
alive
)
:
node_
(
node
),
alive_
(
alive
)
{
node_
->
WrappedBy
(
this
);
}
virtual
~
RunnableOp
()
{
*
alive_
=
false
;
}
private:
Node
*
node_
;
bool
*
alive_
;
};
class
RunnableOp2
{
public:
RunnableOp2
(
Node
*
node
,
bool
*
alive
)
:
node_
(
node
),
alive_
(
alive
)
{
node_
->
WrappedBy
(
this
);
}
virtual
~
RunnableOp2
()
{
*
alive_
=
false
;
}
private:
Node
*
node_
;
bool
*
alive_
;
};
TEST
(
NodeTest
,
Basic
)
{
bool
alive1
=
true
;
bool
alive2
=
true
;
std
::
unique_ptr
<
Node
>
n1
(
CreateNodeForTest
(
"n1"
,
Node
::
Type
::
kVariable
));
std
::
unique_ptr
<
Node
>
n2
(
CreateNodeForTest
(
"n2"
,
Node
::
Type
::
kVariable
));
EXPECT_FALSE
(
n1
->
IsWrappedBy
<
RunnableOp
>
());
EXPECT_FALSE
(
n1
->
IsWrappedBy
<
RunnableOp2
>
());
EXPECT_FALSE
(
n2
->
IsWrappedBy
<
RunnableOp
>
());
EXPECT_FALSE
(
n2
->
IsWrappedBy
<
RunnableOp2
>
());
new
RunnableOp
(
n1
.
get
(),
&
alive1
);
new
RunnableOp2
(
n2
.
get
(),
&
alive2
);
EXPECT_TRUE
(
n1
->
IsWrappedBy
<
RunnableOp
>
());
EXPECT_FALSE
(
n1
->
IsWrappedBy
<
RunnableOp2
>
());
EXPECT_FALSE
(
n2
->
IsWrappedBy
<
RunnableOp
>
());
EXPECT_TRUE
(
n2
->
IsWrappedBy
<
RunnableOp2
>
());
EXPECT_TRUE
(
alive1
);
EXPECT_TRUE
(
alive2
);
n1
.
reset
(
nullptr
);
n2
.
reset
(
nullptr
);
EXPECT_FALSE
(
alive1
);
EXPECT_FALSE
(
alive2
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
3db9fad7
...
@@ -38,8 +38,8 @@ if(WITH_TESTING)
...
@@ -38,8 +38,8 @@ if(WITH_TESTING)
ARGS --word2vec_dirname=
${
WORD2VEC_MODEL_DIR
}
--book_dirname=
${
PYTHON_TESTS_DIR
}
/book
)
ARGS --word2vec_dirname=
${
WORD2VEC_MODEL_DIR
}
--book_dirname=
${
PYTHON_TESTS_DIR
}
/book
)
set_tests_properties
(
test_api_impl PROPERTIES DEPENDS test_image_classification
)
set_tests_properties
(
test_api_impl PROPERTIES DEPENDS test_image_classification
)
endif
()
endif
()
cc_test
(
test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor
${
inference_deps
}
paddle_inference_api
cc_test
(
test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor
${
inference_deps
}
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book
)
ARGS --dirname=
${
WORD2VEC_MODEL_DIR
}
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
cc_library
(
paddle_inference_tensorrt_subgraph_engine
cc_library
(
paddle_inference_tensorrt_subgraph_engine
...
...
paddle/fluid/inference/api/analysis_predictor_tester.cc
浏览文件 @
3db9fad7
...
@@ -24,7 +24,7 @@ using contrib::AnalysisConfig;
...
@@ -24,7 +24,7 @@ using contrib::AnalysisConfig;
TEST
(
AnalysisPredictor
,
ZeroCopy
)
{
TEST
(
AnalysisPredictor
,
ZeroCopy
)
{
AnalysisConfig
config
;
AnalysisConfig
config
;
config
.
model_dir
=
FLAGS_dirname
+
"/word2vec.inference.model"
;
config
.
model_dir
=
FLAGS_dirname
;
config
.
use_feed_fetch_ops
=
false
;
config
.
use_feed_fetch_ops
=
false
;
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
...
...
paddle/fluid/operators/space_to_depth_op.cc
0 → 100644
浏览文件 @
3db9fad7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/space_to_depth_op.h"
#include <string>
#include <vector>
namespace
paddle
{
namespace
operators
{
class
SpaceToDepthOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SpaceToDepthOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of SpaceToDepthOp should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
4
,
"input should be a 4D tensor"
);
auto
blocksize
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"blocksize"
);
PADDLE_ENFORCE_GT
(
blocksize
,
1
,
"The blocksize should be Greater than 1"
);
PADDLE_ENFORCE_GT
(
x_dims
[
1
],
0
,
"input channel should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
2
],
0
,
"input Height should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
3
],
0
,
"input Width should be Greater than 0"
);
PADDLE_ENFORCE_EQ
(
x_dims
[
1
]
%
(
blocksize
*
blocksize
),
0
,
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize"
);
PADDLE_ENFORCE_EQ
(
x_dims
[
2
]
%
(
blocksize
),
0
,
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize"
);
PADDLE_ENFORCE_EQ
(
x_dims
[
3
]
%
(
blocksize
),
0
,
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize"
);
VLOG
(
3
)
<<
"SpaceToDepthOp operator x.shape="
<<
x_dims
<<
"Attribute blocksize"
<<
blocksize
<<
std
::
endl
;
std
::
vector
<
int64_t
>
output_shape
(
4
,
0
);
// [B,C,H,W]
output_shape
[
0
]
=
x_dims
[
0
];
output_shape
[
1
]
=
x_dims
[
1
]
*
blocksize
*
blocksize
;
output_shape
[
2
]
=
x_dims
[
2
]
/
blocksize
;
output_shape
[
3
]
=
x_dims
[
3
]
/
blocksize
;
auto
out_dims
=
framework
::
make_ddim
(
output_shape
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
if
(
x_dims
[
0
]
==
out_dims
[
0
])
{
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
};
class
SpaceToDepthOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor). The input should be a 4D tensor B * C * W * H of "
"SpaceToDepthOp "
"operator."
);
AddOutput
(
"Out"
,
"(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of "
"SpaceToDepthOp operator."
);
AddAttr
<
int64_t
>
(
"blocksize"
,
"(int64_t, default 2) blocksize used to do change Space To Depth."
)
.
SetDefault
(
2
)
.
GreaterThan
(
1
);
AddComment
(
R"DOC(
reorg operator used in Yolo v2.
The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize,
Reshape Input(X) into the shape according to Attr(blocksize). The
data in Input(X) are unchanged.
Examples:
1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the blocksize is 2, the reorg operator will transform Input(X)
into a 4-D tensor with shape [128, 2048, 13, 13] and leaving Input(X)'s data unchanged.
)DOC"
);
}
};
class
SpaceToDepthGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
space_to_depth
,
ops
::
SpaceToDepthOp
,
ops
::
SpaceToDepthOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
space_to_depth_grad
,
ops
::
SpaceToDepthGradOp
);
REGISTER_OP_CPU_KERNEL
(
space_to_depth
,
ops
::
SpaceToDepthKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SpaceToDepthKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
SpaceToDepthKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
space_to_depth_grad
,
ops
::
SpaceToDepthGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SpaceToDepthGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
SpaceToDepthGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/space_to_depth_op.cu
0 → 100644
浏览文件 @
3db9fad7
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/space_to_depth_op.h"
namespace
plat
=
paddle
::
platform
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
space_to_depth
,
ops
::
SpaceToDepthKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SpaceToDepthKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
SpaceToDepthKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
space_to_depth_grad
,
ops
::
SpaceToDepthGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SpaceToDepthGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
SpaceToDepthGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
paddle/fluid/operators/space_to_depth_op.h
0 → 100644
浏览文件 @
3db9fad7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#endif // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
space_to_depth_compute
{
public:
HOSTDEVICE
space_to_depth_compute
(
const
T
*
x
,
int64_t
w
,
int64_t
h
,
int64_t
c
,
int64_t
batch
,
int64_t
blocksize
,
int64_t
forward
,
T
*
out
)
:
x_
(
x
),
w_
(
w
),
h_
(
h
),
c_
(
c
),
batch_
(
batch
),
blocksize_
(
blocksize
),
forward_
(
forward
),
out_
(
out
)
{}
HOSTDEVICE
void
operator
()(
int64_t
in_index
)
{
int64_t
out_c
=
c_
/
(
blocksize_
*
blocksize_
);
// calculate each dim position with index of tensor
int64_t
b
=
in_index
/
(
c_
*
h_
*
w_
);
int64_t
k
=
(
in_index
%
(
c_
*
h_
*
w_
))
/
(
h_
*
w_
);
int64_t
j
=
((
in_index
%
(
c_
*
h_
*
w_
))
%
(
h_
*
w_
))
/
w_
;
int64_t
i
=
((
in_index
%
(
c_
*
h_
*
w_
))
%
(
h_
*
w_
))
%
w_
;
int64_t
c2
=
k
%
out_c
;
int64_t
offset
=
k
/
out_c
;
int64_t
w2
=
i
*
blocksize_
+
offset
%
blocksize_
;
int64_t
h2
=
j
*
blocksize_
+
offset
/
blocksize_
;
int64_t
out_index
=
w2
+
w_
*
blocksize_
*
(
h2
+
h_
*
blocksize_
*
(
c2
+
out_c
*
b
));
if
(
forward_
)
out_
[
out_index
]
=
x_
[
in_index
];
else
out_
[
in_index
]
=
x_
[
out_index
];
}
private:
const
T
*
x_
;
int64_t
w_
,
h_
,
c_
,
batch_
,
blocksize_
,
forward_
;
T
*
out_
;
};
template
<
typename
DeviceContext
,
typename
T
>
class
SpaceToDepthKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
out
=
context
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
x
=
context
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
blocksize
=
context
.
Attr
<
int64_t
>
(
"blocksize"
);
auto
in_dims
=
x
->
dims
();
out
->
mutable_data
(
context
.
GetPlace
(),
x
->
type
());
auto
out_dims
=
out
->
dims
();
auto
B
=
in_dims
[
0
];
auto
C
=
in_dims
[
1
];
auto
H
=
in_dims
[
2
];
auto
W
=
in_dims
[
3
];
platform
::
ForRange
<
DeviceContext
>
for_range
(
context
.
template
device_context
<
DeviceContext
>(),
static_cast
<
size_t
>
(
x
->
numel
()));
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
out_data
=
out
->
data
<
T
>
();
paddle
::
operators
::
space_to_depth_compute
<
T
>
computer
(
x_data
,
W
,
H
,
C
,
B
,
blocksize
,
1
,
out_data
);
for_range
(
computer
);
out
->
Resize
(
out_dims
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SpaceToDepthGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
d_out
=
context
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
blocksize
=
context
.
Attr
<
int64_t
>
(
"blocksize"
);
auto
in_dims
=
d_x
->
dims
();
d_x
->
mutable_data
(
context
.
GetPlace
(),
d_out
->
type
());
auto
B
=
in_dims
[
0
];
auto
C
=
in_dims
[
1
];
auto
H
=
in_dims
[
2
];
auto
W
=
in_dims
[
3
];
platform
::
ForRange
<
DeviceContext
>
for_range
(
context
.
template
device_context
<
DeviceContext
>(),
static_cast
<
size_t
>
(
d_x
->
numel
()));
auto
*
dx_data
=
d_x
->
data
<
T
>
();
auto
*
dout_data
=
d_out
->
data
<
T
>
();
paddle
::
operators
::
space_to_depth_compute
<
T
>
computer
(
dout_data
,
W
,
H
,
C
,
B
,
blocksize
,
0
,
dx_data
);
for_range
(
computer
);
d_x
->
Resize
(
in_dims
);
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
3db9fad7
...
@@ -154,6 +154,7 @@ __all__ = [
...
@@ -154,6 +154,7 @@ __all__ = [
'mul'
,
'mul'
,
'sigmoid_cross_entropy_with_logits'
,
'sigmoid_cross_entropy_with_logits'
,
'maxout'
,
'maxout'
,
'space_to_depth'
,
'affine_grid'
,
'affine_grid'
,
'sequence_reverse'
,
'sequence_reverse'
,
'affine_channel'
,
'affine_channel'
,
...
@@ -7674,6 +7675,66 @@ def maxout(x, groups, name=None):
...
@@ -7674,6 +7675,66 @@ def maxout(x, groups, name=None):
return
out
return
out
def
space_to_depth
(
x
,
blocksize
,
name
=
None
):
"""
Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width]
This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the
input LoDtensor where values from the height and width dimensions are moved to the channel dimension.
The attr blocksize indicates the input block size.
space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according
to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]:
space_to_depth is used to This operation is useful for resizing the activations between convolutions
(but keeping all data)
- Non-overlapping blocks of size block_size x block size are rearranged into depth at each location.
- The depth of the output tensor is block_size * block_size * input channel
- The Y, X coordinates within each block of the input become the high order component of the output channel index
- channel should be divisible by square of blocksize
- height, width should be divsible by blocksize
Args:
x(variable): The input LoDtensor.
blocksize(variable): The blocksize to select the element on each feature map should be > 2
Returns:
Variable: The output LoDtensor.
Raises:
TypeError: blocksize type must be a long.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[1, 4, 2, 2], dtype='float32')
space_to_depthed = fluid.layers.space_to_depth(
x=data, blocksize=2)
"""
helper
=
LayerHelper
(
"space_to_depth"
,
**
locals
())
if
not
(
isinstance
(
blocksize
,
int
)):
raise
ValueError
(
"blocksize must be a python Int"
)
if
name
is
None
:
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
#fix create
else
:
out
=
helper
.
create_variable
(
name
=
name
,
dtype
=
x
.
dtype
,
persistable
=
False
)
helper
.
append_op
(
type
=
"space_to_depth"
,
inputs
=
{
"X"
:
x
},
attrs
=
{
"blocksize"
:
blocksize
},
outputs
=
{
"Out"
:
out
})
return
out
@
templatedoc
()
@
templatedoc
()
def
sequence_reverse
(
x
,
name
=
None
):
def
sequence_reverse
(
x
,
name
=
None
):
"""
"""
...
...
python/paddle/fluid/op.py
浏览文件 @
3db9fad7
...
@@ -108,6 +108,8 @@ class OpDescCreationMethod(object):
...
@@ -108,6 +108,8 @@ class OpDescCreationMethod(object):
new_attr
.
i
=
user_defined_attr
new_attr
.
i
=
user_defined_attr
elif
attr
.
type
==
framework_pb2
.
FLOAT
:
elif
attr
.
type
==
framework_pb2
.
FLOAT
:
new_attr
.
f
=
user_defined_attr
new_attr
.
f
=
user_defined_attr
elif
attr
.
type
==
framework_pb2
.
LONG
:
new_attr
.
l
=
user_defined_attr
elif
attr
.
type
==
framework_pb2
.
STRING
:
elif
attr
.
type
==
framework_pb2
.
STRING
:
new_attr
.
s
=
user_defined_attr
new_attr
.
s
=
user_defined_attr
elif
attr
.
type
==
framework_pb2
.
BOOLEAN
:
elif
attr
.
type
==
framework_pb2
.
BOOLEAN
:
...
...
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
3db9fad7
...
@@ -248,6 +248,17 @@ class TestBook(unittest.TestCase):
...
@@ -248,6 +248,17 @@ class TestBook(unittest.TestCase):
self
.
assertIsNotNone
(
layers
.
softmax
(
hid
))
self
.
assertIsNotNone
(
layers
.
softmax
(
hid
))
print
(
str
(
program
))
print
(
str
(
program
))
def
test_space_to_depth
(
self
):
program
=
Program
()
with
program_guard
(
program
):
data
=
layers
.
data
(
name
=
'data'
,
shape
=
[
32
,
9
,
6
,
6
],
append_batch_size
=
False
,
dtype
=
'float32'
)
self
.
assertIsNotNone
(
layers
.
space_to_depth
(
data
,
3
))
print
(
str
(
program
))
def
test_sequence_unsqueeze
(
self
):
def
test_sequence_unsqueeze
(
self
):
program
=
Program
()
program
=
Program
()
with
program_guard
(
program
):
with
program_guard
(
program
):
...
...
python/paddle/fluid/tests/unittests/test_space_to_depth_op.py
0 → 100644
浏览文件 @
3db9fad7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
op_test
import
OpTest
class
TestSpaceToDepthOp
(
OpTest
):
@
staticmethod
def
helper
(
in_
,
width
,
height
,
channel
,
batch
,
blocksize
,
forward
,
out_
):
channel_out
=
channel
//
(
blocksize
*
blocksize
)
for
b
in
range
(
batch
):
for
k
in
range
(
channel
):
for
j
in
range
(
height
):
for
i
in
range
(
width
):
in_index
=
i
+
width
*
(
j
+
height
*
(
k
+
channel
*
b
))
channel2
=
k
%
channel_out
offset
=
k
//
channel_out
width2
=
i
*
blocksize
+
offset
%
blocksize
height2
=
j
*
blocksize
+
offset
//
blocksize
out_index
=
width2
+
width
*
blocksize
*
(
height2
+
height
*
blocksize
*
(
channel2
+
channel_out
*
b
))
if
forward
:
out_
[
out_index
]
=
in_
[
in_index
]
else
:
out_
[
in_index
]
=
in_
[
out_index
]
def
setUp
(
self
):
self
.
init_data
()
self
.
op_type
=
"space_to_depth"
self
.
inputs
=
{
"X"
:
self
.
x
}
self
.
helper
(
self
.
x_1d
,
self
.
x
.
shape
[
3
],
self
.
x
.
shape
[
2
],
self
.
x
.
shape
[
1
],
self
.
x
.
shape
[
0
],
self
.
blocksize
,
self
.
forward
,
self
.
out_1d
)
self
.
out
=
np
.
reshape
(
self
.
out_1d
,
self
.
infered_shape
)
self
.
attrs
=
{
"blocksize"
:
self
.
blocksize
}
self
.
outputs
=
{
"Out"
:
self
.
out
}
def
init_data
(
self
):
self
.
ori_shape
=
(
32
,
12
,
6
,
6
)
self
.
infered_shape
=
(
32
,
48
,
3
,
3
)
self
.
one_d_len
=
32
*
48
*
3
*
3
self
.
blocksize
=
2
self
.
x
=
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
'float32'
)
self
.
x_1d
=
np
.
reshape
(
self
.
x
,
self
.
one_d_len
)
self
.
out
=
np
.
zeros
(
self
.
infered_shape
).
astype
(
'float32'
)
self
.
out_1d
=
np
.
reshape
(
self
.
out
,
self
.
one_d_len
)
self
.
forward
=
1
def
test_check_output
(
self
):
place
=
fluid
.
core
.
CUDAPlace
(
0
)
if
fluid
.
core
.
is_compiled_with_cuda
(
)
else
fluid
.
core
.
CPUPlace
()
self
.
check_output_with_place
(
place
,
1e-5
,
None
,
False
)
def
test_check_grad
(
self
):
place
=
fluid
.
core
.
CUDAPlace
(
0
)
if
fluid
.
core
.
is_compiled_with_cuda
(
)
else
fluid
.
core
.
CPUPlace
()
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Out'
)
class
TestSpaceToDepthOpBasic
(
TestSpaceToDepthOp
):
def
init_data
(
self
):
self
.
ori_shape
=
(
32
,
8
,
6
,
6
)
self
.
infered_shape
=
(
32
,
32
,
3
,
3
)
self
.
one_d_len
=
32
*
32
*
3
*
3
self
.
blocksize
=
2
self
.
x
=
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
'float32'
)
self
.
x_1d
=
np
.
reshape
(
self
.
x
,
self
.
one_d_len
)
self
.
out
=
np
.
zeros
(
self
.
infered_shape
).
astype
(
'float32'
)
self
.
out_1d
=
np
.
reshape
(
self
.
out
,
self
.
one_d_len
)
self
.
forward
=
1
class
TestSpaceToDepthOpDoubleBasic
(
TestSpaceToDepthOp
):
def
init_data
(
self
):
self
.
ori_shape
=
(
32
,
8
,
6
,
6
)
self
.
infered_shape
=
(
32
,
32
,
3
,
3
)
self
.
one_d_len
=
32
*
32
*
3
*
3
self
.
blocksize
=
2
self
.
x
=
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
'float64'
)
self
.
x_1d
=
np
.
reshape
(
self
.
x
,
self
.
one_d_len
)
self
.
out
=
np
.
zeros
(
self
.
infered_shape
).
astype
(
'float64'
)
self
.
out_1d
=
np
.
reshape
(
self
.
out
,
self
.
one_d_len
)
self
.
forward
=
1
class
TestSpaceToDepthOpWithStride3
(
TestSpaceToDepthOp
):
def
init_data
(
self
):
self
.
ori_shape
=
(
32
,
9
,
6
,
6
)
self
.
infered_shape
=
(
32
,
81
,
2
,
2
)
self
.
one_d_len
=
32
*
81
*
2
*
2
self
.
blocksize
=
3
self
.
x
=
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
'float32'
)
self
.
x_1d
=
np
.
reshape
(
self
.
x
,
self
.
one_d_len
)
self
.
out
=
np
.
zeros
(
self
.
infered_shape
).
astype
(
'float32'
)
self
.
out_1d
=
np
.
reshape
(
self
.
out
,
self
.
one_d_len
)
self
.
forward
=
1
class
TestSpaceToDepthOpWithNotSquare
(
TestSpaceToDepthOp
):
def
init_data
(
self
):
self
.
ori_shape
=
(
32
,
9
,
9
,
6
)
self
.
infered_shape
=
(
32
,
81
,
3
,
2
)
self
.
one_d_len
=
32
*
81
*
3
*
2
self
.
blocksize
=
3
self
.
x
=
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
'float32'
)
self
.
x_1d
=
np
.
reshape
(
self
.
x
,
self
.
one_d_len
)
self
.
out
=
np
.
zeros
(
self
.
infered_shape
).
astype
(
'float32'
)
self
.
out_1d
=
np
.
reshape
(
self
.
out
,
self
.
one_d_len
)
self
.
forward
=
1
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录