Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b123e43b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b123e43b
编写于
3月 24, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
extract multi devices graph builder
上级
dd73d18b
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
354 addition
and
242 deletion
+354
-242
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+2
-7
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+3
-0
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+140
-0
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+46
-0
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+88
-0
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+56
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+19
-235
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
b123e43b
...
...
@@ -88,14 +88,9 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method
)
if
(
WITH_GPU
)
set
(
parallel_executor_cuda_deps nccl_all_reduce_op_handle
)
else
()
set
(
parallel_executor_cuda_deps
)
endif
()
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle
fetch_op_handle computation_op_handle ssa_graph
${
parallel_executor_cuda_deps
}
)
backward glog lod_rank_table simple_threadpool multi_devices_graph_builder fetch_op_handle
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
b123e43b
...
...
@@ -7,3 +7,6 @@ nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_h
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph
)
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
nccl_all_reduce_op_handle scale_loss_grad_op_handle
)
paddle/fluid/framework/details/multi_devices_graph_builder.cc
0 → 100644
浏览文件 @
b123e43b
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
nccl_ctxs_
(
nccl_ctxs
)
{
for
(
auto
&
p
:
params
)
{
grad_names_
.
insert
(
GradVarName
(
p
));
}
}
void
MultiDevSSAGraphBuilder
::
Build
(
const
ProgramDesc
&
program
,
SSAGraph
*
graph
)
const
{
SSAGraph
&
result
=
*
graph
;
result
.
vars_
.
resize
(
places_
.
size
());
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
bool
change_forward
=
false
;
if
(
!
is_forwarding
)
{
// FIXME(yy): Do not hard code like this
if
(
op
->
OutputArgumentNames
().
size
()
==
1
&&
op
->
OutputArgumentNames
()[
0
]
==
GradVarName
(
loss_var_name_
))
{
continue
;
// Drop fill 1. for backward coeff;
}
}
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
*
s
=
local_scopes_
[
i
];
result
.
ops_
.
emplace_back
(
new
ComputationOpHandle
(
*
op
,
s
,
p
));
auto
*
op_handle
=
result
.
ops_
.
back
().
get
();
op_handle
->
dev_ctx_
[
p
]
=
const_cast
<
platform
::
DeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
auto
var_names
=
op
->
InputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
VarHandle
*
var
=
CreateOrGetLatestVarHandle
(
&
result
,
each_var_name
,
p
,
i
);
op_handle
->
AddInput
(
var
);
}
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
CreateOpOutput
(
&
result
,
op_handle
,
each_var_name
,
p
,
i
);
}
if
(
is_forwarding
)
{
if
(
var_names
.
size
()
==
1
&&
var_names
[
0
]
==
loss_var_name_
)
{
// Insert ScaleCost OpHandle
op_handle
=
new
ScaleLossGradOpHandle
(
local_scopes_
.
size
(),
s
,
p
,
nccl_ctxs_
->
DevCtx
(
p
));
result
.
ops_
.
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput
(
&
result
,
op_handle
,
GradVarName
(
loss_var_name_
),
p
,
i
);
change_forward
=
true
;
}
}
}
if
(
change_forward
)
{
is_forwarding
=
false
;
}
if
(
!
is_forwarding
)
{
auto
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
og
:
var_names
)
{
if
(
grad_names_
.
count
(
og
)
!=
0
)
{
// is param grad
// Insert NCCL AllReduce Op
result
.
ops_
.
emplace_back
(
new
NCCLAllReduceOpHandle
(
local_scopes_
,
places_
,
*
nccl_ctxs_
));
auto
*
op_handle
=
result
.
ops_
.
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
vars
=
result
.
vars_
[
i
][
og
];
if
(
vars
.
empty
())
{
// This device has no data. continue.
continue
;
}
auto
*
prev_grad
=
&
vars
[
vars
.
size
()
-
1
];
op_handle
->
AddInput
(
prev_grad
);
auto
&
var
=
vars
[
vars
.
size
()];
var
.
place_
=
p
;
var
.
name_
=
og
;
var
.
version_
=
vars
.
size
()
-
1
;
op_handle
->
AddOutput
(
&
var
);
}
}
}
}
}
/*
Dependency graph has been constructed. However, there are still data
harzaeds need to be handled.
*/
PolishGraphToSupportDataHazards
(
&
result
);
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/multi_devices_graph_builder.h
0 → 100644
浏览文件 @
b123e43b
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace
paddle
{
namespace
platform
{
class
NCCLContextMap
;
}
namespace
framework
{
class
Scope
;
namespace
details
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
public:
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
);
void
Build
(
const
ProgramDesc
&
program
,
SSAGraph
*
graph
)
const
override
;
private:
std
::
string
loss_var_name_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
platform
::
NCCLContextMap
*
nccl_ctxs_
;
std
::
unordered_set
<
std
::
string
>
grad_names_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_builder.cc
0 → 100644
浏览文件 @
b123e43b
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
SSAGraph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
vars_
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
return
;
}
auto
it_new
=
name_pair
.
second
.
rbegin
();
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
auto
*
write_op
=
it_new
->
second
.
generated_op_
;
auto
&
read_ops
=
it_old
->
second
.
pending_ops_
;
auto
*
ex_write_op
=
it_old
->
second
.
generated_op_
;
if
(
ex_write_op
==
nullptr
)
{
// Nobody write this var.
continue
;
}
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
auto
*
dep_var
=
new
DummyVarHandle
();
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
dep_vars_
.
emplace
(
dep_var
);
}
}
}
}
}
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
SSAGraph
*
graph
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
vars_
[
place_offset
];
auto
&
var_holder
=
var_holders
[
each_var_name
];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
auto
&
init_var
=
var_holder
[
0
];
init_var
.
place_
=
place
;
init_var
.
name_
=
each_var_name
;
init_var
.
generated_op_
=
nullptr
;
init_var
.
version_
=
0
;
var
=
&
init_var
;
}
else
{
var
=
&
var_holder
.
rbegin
()
->
second
;
}
return
var
;
}
void
SSAGraphBuilder
::
CreateOpOutput
(
SSAGraph
*
graph
,
OpHandleBase
*
op_handle
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
vars_
[
place_offset
][
each_var_name
];
size_t
version
=
vars
.
size
();
auto
&
var
=
vars
[
version
];
var
.
version_
=
version
;
var
.
name_
=
each_var_name
;
var
.
place_
=
place
;
op_handle
->
AddOutput
(
&
var
);
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_builder.h
0 → 100644
浏览文件 @
b123e43b
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"
#include <string>
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
SSAGraphBuilder
{
public:
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
void
Build
(
const
ProgramDesc
&
program
,
SSAGraph
*
graph
)
const
=
0
;
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static
void
PolishGraphToSupportDataHazards
(
SSAGraph
*
graph
);
static
VarHandle
*
CreateOrGetLatestVarHandle
(
SSAGraph
*
graph
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
static
void
CreateOpOutput
(
SSAGraph
*
graph
,
OpHandleBase
*
op_handle
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
b123e43b
...
...
@@ -16,231 +16,14 @@ limitations under the License. */
#include "ThreadPool.h"
#include "lod_tensor.h"
#include "op_registry.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace
paddle
{
namespace
framework
{
using
details
::
ComputationOpHandle
;
using
details
::
DummyVarHandle
;
using
details
::
FetchOpHandle
;
using
details
::
NCCLAllReduceOpHandle
;
using
details
::
OpHandleBase
;
using
details
::
ScaleLossGradOpHandle
;
using
details
::
SSAGraph
;
using
details
::
VarHandle
;
using
details
::
VarHandleBase
;
class
SSAGraphBuilder
{
public:
virtual
~
SSAGraphBuilder
()
{}
virtual
void
Build
(
const
ProgramDesc
&
program
,
SSAGraph
*
graph
)
const
=
0
;
protected:
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static
void
PolishGraphToSupportDataHazards
(
SSAGraph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
vars_
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
return
;
}
auto
it_new
=
name_pair
.
second
.
rbegin
();
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
auto
*
write_op
=
it_new
->
second
.
generated_op_
;
auto
&
read_ops
=
it_old
->
second
.
pending_ops_
;
auto
*
ex_write_op
=
it_old
->
second
.
generated_op_
;
if
(
ex_write_op
==
nullptr
)
{
// Nobody write this var.
continue
;
}
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
auto
*
dep_var
=
new
DummyVarHandle
();
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
dep_vars_
.
emplace
(
dep_var
);
}
}
}
}
}
static
VarHandle
*
CreateOrGetLatestVarHandle
(
SSAGraph
*
graph
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
vars_
[
place_offset
];
auto
&
var_holder
=
var_holders
[
each_var_name
];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
auto
&
init_var
=
var_holder
[
0
];
init_var
.
place_
=
place
;
init_var
.
name_
=
each_var_name
;
init_var
.
generated_op_
=
nullptr
;
init_var
.
version_
=
0
;
var
=
&
init_var
;
}
else
{
var
=
&
var_holder
.
rbegin
()
->
second
;
}
return
var
;
}
static
void
CreateOpOutput
(
SSAGraph
*
graph
,
OpHandleBase
*
op_handle
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
vars_
[
place_offset
][
each_var_name
];
size_t
version
=
vars
.
size
();
auto
&
var
=
vars
[
version
];
var
.
version_
=
version
;
var
.
name_
=
each_var_name
;
var
.
place_
=
place
;
op_handle
->
AddOutput
(
&
var
);
}
};
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
public:
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
nccl_ctxs_
(
nccl_ctxs
)
{
for
(
auto
&
p
:
params
)
{
grad_names_
.
insert
(
GradVarName
(
p
));
}
}
void
Build
(
const
ProgramDesc
&
program
,
SSAGraph
*
graph
)
const
override
{
SSAGraph
&
result
=
*
graph
;
result
.
vars_
.
resize
(
places_
.
size
());
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
bool
change_forward
=
false
;
if
(
!
is_forwarding
)
{
// FIXME(yy): Do not hard code like this
if
(
op
->
OutputArgumentNames
().
size
()
==
1
&&
op
->
OutputArgumentNames
()[
0
]
==
GradVarName
(
loss_var_name_
))
{
continue
;
// Drop fill 1. for backward coeff;
}
}
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
*
s
=
local_scopes_
[
i
];
result
.
ops_
.
emplace_back
(
new
ComputationOpHandle
(
*
op
,
s
,
p
));
auto
*
op_handle
=
result
.
ops_
.
back
().
get
();
op_handle
->
dev_ctx_
[
p
]
=
const_cast
<
platform
::
DeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
auto
var_names
=
op
->
InputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
VarHandle
*
var
=
CreateOrGetLatestVarHandle
(
&
result
,
each_var_name
,
p
,
i
);
op_handle
->
AddInput
(
var
);
}
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
each_var_name
:
var_names
)
{
CreateOpOutput
(
&
result
,
op_handle
,
each_var_name
,
p
,
i
);
}
if
(
is_forwarding
)
{
if
(
var_names
.
size
()
==
1
&&
var_names
[
0
]
==
loss_var_name_
)
{
// Insert ScaleCost OpHandle
op_handle
=
new
ScaleLossGradOpHandle
(
local_scopes_
.
size
(),
s
,
p
,
nccl_ctxs_
->
DevCtx
(
p
));
result
.
ops_
.
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput
(
&
result
,
op_handle
,
GradVarName
(
loss_var_name_
),
p
,
i
);
change_forward
=
true
;
}
}
}
if
(
change_forward
)
{
is_forwarding
=
false
;
}
if
(
!
is_forwarding
)
{
auto
var_names
=
op
->
OutputArgumentNames
();
for
(
auto
&
og
:
var_names
)
{
if
(
grad_names_
.
count
(
og
)
!=
0
)
{
// is param grad
// Insert NCCL AllReduce Op
result
.
ops_
.
emplace_back
(
new
NCCLAllReduceOpHandle
(
local_scopes_
,
places_
,
*
nccl_ctxs_
));
auto
*
op_handle
=
result
.
ops_
.
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
vars
=
result
.
vars_
[
i
][
og
];
if
(
vars
.
empty
())
{
// This device has no data. continue.
continue
;
}
auto
*
prev_grad
=
&
vars
[
vars
.
size
()
-
1
];
op_handle
->
AddInput
(
prev_grad
);
auto
&
var
=
vars
[
vars
.
size
()];
var
.
place_
=
p
;
var
.
name_
=
og
;
var
.
version_
=
vars
.
size
()
-
1
;
op_handle
->
AddOutput
(
&
var
);
}
}
}
}
}
/*
Dependency graph has been constructed. However, there are still data
harzaeds need to be handled.
*/
PolishGraphToSupportDataHazards
(
&
result
);
}
private:
std
::
string
loss_var_name_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
platform
::
NCCLContextMap
*
nccl_ctxs_
;
std
::
unordered_set
<
std
::
string
>
grad_names_
;
};
class
ParallelExecutorPrivate
{
public:
explicit
ParallelExecutorPrivate
(
size_t
num_threads
,
...
...
@@ -256,17 +39,17 @@ class ParallelExecutorPrivate {
std
::
unique_ptr
<
platform
::
NCCLContextMap
>
nccl_ctxs_
;
SSAGraph
graph_
;
details
::
SSAGraph
graph_
;
// Use a simpler thread pool, might be faster.
std
::
unique_ptr
<
ThreadPool
>
pool_
;
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
void
RunOp
(
bool
use_event
,
std
::
unordered_map
<
VarHandleBase
*
,
std
::
atomic
<
bool
>>
&
pending_vars
,
OpHandleBase
*
op
)
{
void
RunOp
(
bool
use_event
,
std
::
unordered_map
<
details
::
VarHandleBase
*
,
std
::
atomic
<
bool
>>
&
pending_vars
,
details
::
OpHandleBase
*
op
)
{
std
::
vector
<
std
::
atomic
<
bool
>
*>
*
ready_buffer
=
new
std
::
vector
<
std
::
atomic
<
bool
>
*>
();
for
(
auto
*
var
:
op
->
outputs_
)
{
...
...
@@ -321,8 +104,8 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
details
::
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
nccl_ctxs_
.
get
());
builder
.
Build
(
main_program
,
&
member_
->
graph_
);
...
...
@@ -389,9 +172,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
FeedFetchList
fetched_data
(
fetch_tensors
.
size
());
// Version --> VarHandle
member_
->
exception_
.
reset
();
std
::
unordered_map
<
VarHandleBase
*
,
std
::
atomic
<
bool
>>
pending_vars
;
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
vector
<
DummyVarHandle
>
dummy_vars
;
std
::
unordered_map
<
details
::
VarHandleBase
*
,
std
::
atomic
<
bool
>>
pending_vars
;
std
::
unordered_map
<
details
::
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
vector
<
details
::
DummyVarHandle
>
dummy_vars
;
for
(
auto
&
var_map
:
member_
->
graph_
.
vars_
)
{
for
(
auto
&
name_pair
:
var_map
)
{
...
...
@@ -406,7 +189,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
pending_vars
[
var
.
get
()]
=
var
->
generated_op_
==
nullptr
;
}
std
::
vector
<
OpHandleBase
*>
to_run
;
std
::
vector
<
details
::
OpHandleBase
*>
to_run
;
for
(
auto
&
op
:
member_
->
graph_
.
ops_
)
{
if
(
op
->
inputs_
.
empty
())
{
// Special case, Op has no input.
...
...
@@ -416,7 +199,8 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
}
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
details
::
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
member_
->
graph_
.
vars_
)
{
...
...
@@ -427,13 +211,13 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
}
std
::
vector
<
FetchOpHandle
>
fetch_ops
;
std
::
vector
<
details
::
FetchOpHandle
>
fetch_ops
;
for
(
size_t
i
=
0
;
i
<
fetch_tensors
.
size
();
++
i
)
{
auto
&
var_name
=
fetch_tensors
[
i
];
auto
&
vars
=
fetched_vars
[
var_name
];
fetch_ops
.
emplace_back
(
&
fetched_data
,
i
,
&
member_
->
local_scopes_
);
FetchOpHandle
*
op
=
&
fetch_ops
.
back
();
details
::
FetchOpHandle
*
op
=
&
fetch_ops
.
back
();
// FIXME: Use new device context
for
(
auto
&
p
:
member_
->
places_
)
{
...
...
@@ -457,7 +241,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
while
(
!
pending_vars
.
empty
())
{
VarHandleBase
*
ready_var
=
nullptr
;
details
::
VarHandleBase
*
ready_var
=
nullptr
;
for
(
auto
&
pair
:
pending_vars
)
{
if
(
pair
.
second
.
load
(
std
::
memory_order_acquire
))
{
ready_var
=
pair
.
first
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录