Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ab72d28a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ab72d28a
编写于
7月 26, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean up and correctness check
上级
aa1085dd
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
184 addition
and
92 deletion
+184
-92
doc/fluid/design/ir/draft.md
doc/fluid/design/ir/draft.md
+9
-1
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+57
-46
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+2
-2
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+7
-6
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+4
-0
paddle/fluid/framework/details/ssa_graph_checker.cc
paddle/fluid/framework/details/ssa_graph_checker.cc
+8
-4
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+2
-2
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+3
-3
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+2
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+4
-4
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+6
-2
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+5
-3
paddle/fluid/framework/ir/graph_viz_pass.h
paddle/fluid/framework/ir/graph_viz_pass.h
+2
-2
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+16
-0
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+56
-14
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+1
-1
未找到文件。
doc/fluid/design/ir/draft.md
浏览文件 @
ab72d28a
...
...
@@ -75,7 +75,12 @@ can also fuse some `Graph`'s `Node`s.
class
Pass
{
public:
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
// Some correctness check.
auto
new_graph
=
ApplyImpl
(
std
::
move
(
graph
));
// Some correctness check.
return
new_graph
;
}
// Get a reference to the attributed previously set.
template
<
typename
AttrType
>
...
...
@@ -89,6 +94,9 @@ class Pass {
// should delete the attribute.
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
);
protected:
virtual
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
};
// In my_pass.cc
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
ab72d28a
...
...
@@ -34,16 +34,22 @@ namespace paddle {
namespace
framework
{
namespace
details
{
static
const
char
kLossVarName
[]
=
"loss_var_name"
;
static
const
char
kPlaces
[]
=
"places"
;
static
const
char
kParams
[]
=
"params"
;
static
const
char
kLocalScopes
[]
=
"local_scopes"
;
static
const
char
kStrategy
[]
=
"strategy"
;
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
loss_var_name_
=
Get
<
const
std
::
string
>
(
"loss_var_name"
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
);
strategy_
=
Get
<
const
BuildStrategy
>
(
"strategy"
);
loss_var_name_
=
Get
<
const
std
::
string
>
(
kLossVarName
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
);
strategy_
=
Get
<
const
BuildStrategy
>
(
kStrategy
);
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
);
#endif
for
(
auto
&
p
:
Get
<
const
std
::
unordered_set
<
std
::
string
>>
(
"params"
))
{
for
(
auto
&
p
:
Get
<
const
std
::
unordered_set
<
std
::
string
>>
(
kParams
))
{
grad_names_
.
insert
(
GradVarName
(
p
));
}
balance_vars_
.
resize
(
places_
.
size
(),
0
);
...
...
@@ -58,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
...
@@ -225,7 +231,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
return
sorted_ret
;
}
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
Init
();
// Give the topology sort order and rebuild the graph structure.
...
...
@@ -241,10 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
result
.
Set
(
"vars"
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
"dep_vars"
,
new
GraphDepVars
);
result
.
Set
(
"ops"
,
new
GraphOps
);
result
.
Set
(
"sharded_var_device"
,
new
ShardedVarDevice
);
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kShardedVarDevice
,
new
ShardedVarDevice
);
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
...
...
@@ -281,7 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
graph
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
}
else
{
...
...
@@ -319,7 +325,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
graph
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
g_name
,
cur_device_id
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
break
;
...
...
@@ -406,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
);
#endif
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
auto
*
in
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
op_handle
->
AddInput
(
in
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
i
).
at
(
p_name
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
i
).
at
(
p_name
);
auto
*
out_var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
p_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
p_name
,
p
);
...
...
@@ -427,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
...
...
@@ -436,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
...
...
@@ -465,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
d_name
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
auto
var
=
new
VarHandle
(
...
...
@@ -524,7 +530,7 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
)
const
{
auto
&
sharded_var_device
=
graph
.
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
);
auto
&
sharded_var_device
=
graph
.
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
);
auto
got
=
sharded_var_device
.
find
(
varname
);
return
got
==
sharded_var_device
.
end
()
?
-
1
:
got
->
second
;
}
...
...
@@ -544,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
communication_dev_ctx
);
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
...
...
@@ -565,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
));
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
}
...
...
@@ -575,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
}
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
dst_dev_id
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
dst_dev_id
][
og
];
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
...
...
@@ -606,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
// on it.
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateControlDepVar
());
prev_op
->
AddOutput
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
}
}
...
...
@@ -635,18 +641,18 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
for
(
auto
&
varname
:
output_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
for
(
auto
&
varname
:
output_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
{
...
...
@@ -661,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"fetch_barrier"
);
}
}
...
...
@@ -687,7 +693,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
...
...
@@ -698,7 +704,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
for
(
auto
&
varname
:
output_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
{
...
...
@@ -709,17 +715,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find the right place for rpc op: %s"
,
node
->
Op
()
->
Type
());
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
RPCOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
RPCOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
*
node
->
Op
(),
local_scopes_
[
op_dev_id
],
node
->
Op
()
->
Type
(),
places_
[
op_dev_id
]));
if
(
node
->
Op
()
->
Type
()
==
"send_barrier"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"send"
);
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"send"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"send_barrier"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"recv"
);
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"recv"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// do nothing
}
else
{
...
...
@@ -743,4 +749,9 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
}
// namespace paddle
REGISTER_PASS
(
multi_device_pass
,
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
);
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLossVarName
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kParams
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kStrategy
);
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
ab72d28a
...
...
@@ -31,8 +31,8 @@ class Scope;
namespace
details
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
p
ublic
:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
p
rotected
:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
private:
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
ab72d28a
...
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
framework
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
...
...
@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
}
}
...
...
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
...
...
@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
][
new_node
->
Name
()];
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
...
...
@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
ab72d28a
...
...
@@ -39,15 +39,19 @@ namespace details {
typedef
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
GraphVars
;
const
char
kGraphVars
[]
=
"vars"
;
// aux variables to represent dependency. Useful to resolve data hazard.
typedef
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
GraphDepVars
;
const
char
kGraphDepVars
[]
=
"dep_vars"
;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
typedef
std
::
unordered_map
<
std
::
string
,
int
>
ShardedVarDevice
;
const
char
kShardedVarDevice
[]
=
"sharded_var_device"
;
class
SSAGraphBuilder
:
public
ir
::
Pass
{
public:
...
...
paddle/fluid/framework/details/ssa_graph_checker.cc
浏览文件 @
ab72d28a
...
...
@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
};
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
insert_pending_var
(
version_pair
.
get
());
...
...
@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
insert_pending_var
(
var
.
get
());
}
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
ready_ops
.
insert
(
op
.
get
());
}
else
{
...
...
@@ -87,4 +87,8 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
// namespace paddle
REGISTER_PASS
(
multi_device_check_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
);
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphDepVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphOps
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kShardedVarDevice
);
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
ab72d28a
...
...
@@ -23,8 +23,8 @@ namespace framework {
namespace
details
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
p
ublic
:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
p
rotected
:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
PADDLE_ENFORCE
(
IsValidGraph
(
graph
.
get
()));
return
graph
;
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
浏览文件 @
ab72d28a
...
...
@@ -22,7 +22,7 @@ namespace details {
template
<
typename
Callback
>
static
inline
void
IterAllVar
(
const
ir
::
Graph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
*
pair2
);
...
...
@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
for
(
auto
&
var
:
graph
.
Get
<
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph
.
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
callback
(
*
var
);
}
}
...
...
@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
kGraphOps
))
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
...
...
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
ab72d28a
...
...
@@ -36,8 +36,8 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
p
ublic
:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
p
rotected
:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
Get
<
const
std
::
string
>
(
"debug_graphviz_path"
)));
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
ab72d28a
...
...
@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
// Transform SSAGraph to pending_ops & pending_vars
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
version_pair
.
get
());
}
}
}
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
var
.
get
());
}
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
details
::
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
.
get
());
}
else
{
...
...
@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
ab72d28a
...
...
@@ -40,10 +40,14 @@ class Graph {
attr_dels_
.
clear
();
}
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
return
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
();
}
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
()
,
"%s attr not registered for graph."
,
attr_name
);
PADDLE_ENFORCE
(
Has
(
attr_name
),
"%s attr not registered for graph."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
浏览文件 @
ab72d28a
...
...
@@ -20,10 +20,11 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
namespace
ir
{
static
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
const
std
::
string
graph_viz_path
=
Get
<
std
::
string
>
(
"graph_viz_path"
);
const
std
::
string
graph_viz_path
=
Get
<
std
::
string
>
(
kGraphVizPath
);
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path
));
PADDLE_ENFORCE
(
fout
->
good
());
std
::
ostream
&
sout
=
*
fout
;
...
...
@@ -67,4 +68,5 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
);
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
)
.
RequirePassAttr
(
paddle
::
framework
::
ir
::
kGraphVizPath
);
paddle/fluid/framework/ir/graph_viz_pass.h
浏览文件 @
ab72d28a
...
...
@@ -28,8 +28,8 @@ namespace framework {
namespace
ir
{
class
GraphVizPass
:
public
Pass
{
p
ublic
:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
p
rotected
:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
ab72d28a
...
...
@@ -17,6 +17,22 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
namespace
ir
{
std
::
unique_ptr
<
Graph
>
Pass
::
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
for
(
const
std
::
string
&
attr
:
required_pass_attrs_
)
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr
)
!=
attrs_
.
end
(),
"Required pass atrribute %s not registered."
,
attr
);
}
for
(
const
std
::
string
&
attr
:
required_graph_attrs_
)
{
PADDLE_ENFORCE
(
graph
->
Has
(
attr
),
"Required graph atrribute %s not exist."
,
attr
);
}
auto
applied_graph
=
ApplyImpl
(
std
::
move
(
graph
));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE
(
!
HasCircle
(
*
applied_graph
),
"Illegal Pass. Generated graph shouldn't has cycle."
);
return
applied_graph
;
}
PassRegistry
&
PassRegistry
::
Instance
()
{
static
PassRegistry
g_pass_info_map
;
return
g_pass_info_map
;
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
ab72d28a
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
...
...
@@ -26,6 +27,8 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
namespace
ir
{
template
<
typename
PassType
>
struct
PassRegistrar
;
class
Pass
{
public:
...
...
@@ -40,7 +43,7 @@ class Pass {
attr_dels_
.
clear
();
}
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
;
// Get a reference to the attributed previously set.
template
<
typename
AttrType
>
...
...
@@ -69,7 +72,25 @@ class Pass {
attrs_
[
attr_name
]
=
attr
;
}
protected:
virtual
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
private:
template
<
typename
PassType
>
friend
struct
PassRegistrar
;
void
RegisterRequiredPassAttrs
(
const
std
::
unordered_set
<
std
::
string
>
&
attrs
)
{
required_pass_attrs_
.
insert
(
attrs
.
begin
(),
attrs
.
end
());
}
void
RegisterRequiredGraphAttrs
(
const
std
::
unordered_set
<
std
::
string
>
&
attrs
)
{
required_graph_attrs_
.
insert
(
attrs
.
begin
(),
attrs
.
end
());
}
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
};
...
...
@@ -119,10 +140,28 @@ struct PassRegistrar : public Registrar {
explicit
PassRegistrar
(
const
char
*
pass_type
)
{
PADDLE_ENFORCE
(
!
PassRegistry
::
Instance
().
Has
(
pass_type
),
"'%s' is registered more than once."
,
pass_type
);
PassRegistry
::
Instance
().
Insert
(
pass_type
,
[]()
->
std
::
unique_ptr
<
Pass
>
{
return
std
::
unique_ptr
<
Pass
>
(
new
PassType
());
PassRegistry
::
Instance
().
Insert
(
pass_type
,
[
this
]()
->
std
::
unique_ptr
<
Pass
>
{
std
::
unique_ptr
<
Pass
>
pass
(
new
PassType
());
pass
->
RegisterRequiredPassAttrs
(
this
->
required_pass_attrs_
);
pass
->
RegisterRequiredGraphAttrs
(
this
->
required_graph_attrs_
);
return
pass
;
});
}
PassRegistrar
<
PassType
>
&
RequirePassAttr
(
const
std
::
string
&
attr
)
{
required_pass_attrs_
.
insert
(
attr
);
return
*
this
;
}
PassRegistrar
<
PassType
>
&
RequireGraphAttr
(
const
std
::
string
&
attr
)
{
required_graph_attrs_
.
insert
(
attr
);
return
*
this
;
}
private:
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
...
...
@@ -141,7 +180,10 @@ struct PassRegistrar : public Registrar {
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
}
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
ab72d28a
...
...
@@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
if
(
member_
->
executor_
)
{
auto
&
sharded_var_device
=
member_
->
executor_
->
Graph
().
Get
<
details
::
ShardedVarDevice
>
(
"sharded_var_device"
);
details
::
kShardedVarDevice
);
if
(
sharded_var_device
.
find
(
var
)
!=
sharded_var_device
.
end
())
{
var_dev_id
=
sharded_var_device
.
at
(
var
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录