Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c50c22b0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c50c22b0
编写于
1月 10, 2022
作者:
L
LiYuRio
提交者:
GitHub
1月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Fleet Executor] Modified python cache strategy to support multi carriers (#38839)
上级
ededcda2
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
103 addition
and
96 deletion
+103
-96
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+47
-7
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+5
-3
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+8
-47
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+4
-7
paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
...luid/distributed/fleet_executor/fleet_executor_desc.proto
+0
-1
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+39
-31
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
c50c22b0
...
@@ -19,7 +19,9 @@
...
@@ -19,7 +19,9 @@
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -43,18 +45,24 @@ void Carrier::Init(
...
@@ -43,18 +45,24 @@ void Carrier::Init(
int64_t
rank
,
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
framework
::
ProgramDesc
&
program
,
framework
::
Scope
*
scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
int64_t
num_micro_batches
,
const
platform
::
Place
&
place
)
{
const
platform
::
Place
&
place
)
{
rank_
=
rank
;
rank_
=
rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
minibatch_scope_
=
minibatch_scope
;
microbatch_scopes_
=
microbatch_scopes
;
place_
=
place
;
place_
=
place
;
root_scope_
=
root_
scope
;
root_scope_
=
scope
;
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
PADDLE_ENFORCE_NOT_NULL
(
root_scope_
,
platform
::
errors
::
InvalidArgument
(
"root_scope can not be nullptr"
));
minibatch_scope_
=
&
root_scope_
->
NewScope
();
microbatch_scopes_
.
resize
(
num_micro_batches
);
for
(
int
i
=
0
;
i
<
num_micro_batches
;
++
i
)
{
microbatch_scopes_
[
i
]
=
&
minibatch_scope_
->
NewScope
();
CopyParameters
(
i
,
program
);
}
// TODO(fleet_exe dev): thread pool
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
SetThreadNum
(
thread_num_
);
...
@@ -64,10 +72,33 @@ void Carrier::Init(
...
@@ -64,10 +72,33 @@ void Carrier::Init(
is_init_
=
true
;
is_init_
=
true
;
}
}
void
Carrier
::
Release
()
{}
void
Carrier
::
Release
()
{
if
(
root_scope_
)
{
root_scope_
->
DropKids
();
}
}
Carrier
::~
Carrier
()
{
VLOG
(
3
)
<<
"Carrier's destructor."
;
}
Carrier
::~
Carrier
()
{
VLOG
(
3
)
<<
"Carrier's destructor."
;
}
void
Carrier
::
CopyParameters
(
int
microbatch_id
,
const
framework
::
ProgramDesc
&
program
)
{
auto
&
global_block
=
program
.
Block
(
0
);
for
(
auto
&
var
:
global_block
.
AllVars
())
{
if
(
var
->
Persistable
()
&&
microbatch_id
==
0
)
{
auto
*
ptr
=
root_scope_
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
5
)
<<
"Create persistable var: "
<<
var
->
Name
()
<<
", which pointer is "
<<
ptr
;
}
else
if
(
!
var
->
Persistable
())
{
auto
*
ptr
=
microbatch_scopes_
[
microbatch_id
]
->
Var
(
var
->
Name
());
VLOG
(
5
)
<<
"Create variable "
<<
var
->
Name
()
<<
" for microbatch "
<<
microbatch_id
<<
", which pointer is "
<<
ptr
<<
"."
;
InitializeVariable
(
ptr
,
var
->
GetType
());
}
}
}
bool
Carrier
::
EnqueueInterceptorMessage
(
bool
Carrier
::
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
const
InterceptorMessage
&
interceptor_message
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
...
@@ -116,6 +147,15 @@ void Carrier::Start() {
...
@@ -116,6 +147,15 @@ void Carrier::Start() {
// TODO(wangxi): async step
// TODO(wangxi): async step
Wait
();
Wait
();
dev_ctx_
->
Wait
();
dev_ctx_
->
Wait
();
for
(
auto
*
micro_scope
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
// But when while_op also create a local executor to run it's sub block,
// the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
micro_scope
->
DropKids
();
}
}
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
c50c22b0
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
Scope
;
class
Scope
;
class
ProgramDesc
;
}
}
namespace
distributed
{
namespace
distributed
{
...
@@ -55,9 +56,10 @@ class Carrier final {
...
@@ -55,9 +56,10 @@ class Carrier final {
int64_t
rank
,
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
framework
::
ProgramDesc
&
program
,
framework
::
Scope
*
scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
int64_t
num_micro_batches
,
const
platform
::
Place
&
place
);
const
platform
::
Place
&
place
);
void
CopyParameters
(
int
microbatch_id
,
const
framework
::
ProgramDesc
&
program
);
void
Release
();
void
Release
();
void
Wait
();
void
Wait
();
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
c50c22b0
...
@@ -22,8 +22,6 @@
...
@@ -22,8 +22,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -38,7 +36,6 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
...
@@ -38,7 +36,6 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
}
}
FleetExecutor
::~
FleetExecutor
()
{
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
for
(
const
auto
&
carrier_id
:
carrier_ids_
)
{
for
(
const
auto
&
carrier_id
:
carrier_ids_
)
{
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
Release
();
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
Release
();
}
}
...
@@ -47,7 +44,7 @@ FleetExecutor::~FleetExecutor() {
...
@@ -47,7 +44,7 @@ FleetExecutor::~FleetExecutor() {
void
FleetExecutor
::
Init
(
void
FleetExecutor
::
Init
(
const
std
::
string
&
carrier_id
,
const
framework
::
ProgramDesc
&
program_desc
,
const
std
::
string
&
carrier_id
,
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
int64_t
num_micro_batches
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
)
{
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
)
{
PADDLE_ENFORCE_GT
(
task_nodes
.
size
(),
0
,
PADDLE_ENFORCE_GT
(
task_nodes
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -72,31 +69,23 @@ void FleetExecutor::Init(
...
@@ -72,31 +69,23 @@ void FleetExecutor::Init(
for
(
auto
&
unique_op
:
ops
)
{
for
(
auto
&
unique_op
:
ops
)
{
unique_op
.
release
();
unique_op
.
release
();
}
}
root_scope_
=
scope
;
place_
=
place
;
PADDLE_ENFORCE_NOT_NULL
(
root_scope_
,
platform
::
errors
::
InvalidArgument
(
"root_scope_ can not be nullptr"
));
minibatch_scope_
=
&
root_scope_
->
NewScope
();
int64_t
num_micro_batches
=
exe_desc_
.
num_micro_batches
();
microbatch_scopes_
.
resize
(
num_micro_batches
);
for
(
int
i
=
0
;
i
<
num_micro_batches
;
++
i
)
{
microbatch_scopes_
[
i
]
=
&
minibatch_scope_
->
NewScope
();
CopyParameters
(
i
,
program_desc
);
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier_ids_
.
insert
(
carrier_id
);
carrier_ids_
.
insert
(
carrier_id
);
// Set current running carrier
// Set current running carrier
GlobalVal
<
std
::
string
>::
Set
(
new
std
::
string
(
carrier_id
));
GlobalVal
<
std
::
string
>::
Set
(
new
std
::
string
(
carrier_id
));
InitCarrier
(
carrier
);
InitCarrier
(
carrier
,
scope
,
place
,
num_micro_batches
,
program_desc
);
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
}
}
void
FleetExecutor
::
InitCarrier
(
Carrier
*
carrier
)
{
void
FleetExecutor
::
InitCarrier
(
Carrier
*
carrier
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
int64_t
num_micro_batches
,
const
framework
::
ProgramDesc
&
program_desc
)
{
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
runtime_graph_
->
interceptor_id_to_node
(),
root_scope_
,
runtime_graph_
->
interceptor_id_to_node
(),
program_desc
,
scope
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
num_micro_batches
,
place
);
}
}
void
FleetExecutor
::
InitMessageBus
()
{
void
FleetExecutor
::
InitMessageBus
()
{
...
@@ -140,34 +129,6 @@ void FleetExecutor::Run(const std::string& carrier_id) {
...
@@ -140,34 +129,6 @@ void FleetExecutor::Run(const std::string& carrier_id) {
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
}
}
carrier
->
Start
();
carrier
->
Start
();
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
// But when while_op also create a local executor to run it's sub block,
// the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
micro_scop
->
DropKids
();
}
}
void
FleetExecutor
::
CopyParameters
(
int
microbatch_id
,
const
framework
::
ProgramDesc
&
program
)
{
auto
&
global_block
=
program
.
Block
(
0
);
for
(
auto
&
var
:
global_block
.
AllVars
())
{
if
(
var
->
Persistable
()
&&
microbatch_id
==
0
)
{
auto
*
ptr
=
root_scope_
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
5
)
<<
"Create persistable var: "
<<
var
->
Name
()
<<
", which pointer is "
<<
ptr
;
}
else
if
(
!
var
->
Persistable
())
{
auto
*
ptr
=
microbatch_scopes_
[
microbatch_id
]
->
Var
(
var
->
Name
());
VLOG
(
5
)
<<
"Create variable "
<<
var
->
Name
()
<<
" for microbatch "
<<
microbatch_id
<<
", which pointer is "
<<
ptr
<<
"."
;
InitializeVariable
(
ptr
,
var
->
GetType
());
}
}
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
c50c22b0
...
@@ -39,7 +39,7 @@ class FleetExecutor final {
...
@@ -39,7 +39,7 @@ class FleetExecutor final {
~
FleetExecutor
();
~
FleetExecutor
();
void
Init
(
const
std
::
string
&
carrier_id
,
void
Init
(
const
std
::
string
&
carrier_id
,
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
int64_t
num_micro_batches
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
(
const
std
::
string
&
carrier_id
);
void
Run
(
const
std
::
string
&
carrier_id
);
...
@@ -47,14 +47,11 @@ class FleetExecutor final {
...
@@ -47,14 +47,11 @@ class FleetExecutor final {
private:
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
void
InitMessageBus
();
void
InitMessageBus
();
void
InitCarrier
(
Carrier
*
carrier
);
void
InitCarrier
(
Carrier
*
carrier
,
framework
::
Scope
*
scope
,
void
CopyParameters
(
int
microbatch_id
,
const
framework
::
ProgramDesc
&
program
);
const
platform
::
Place
&
place
,
int64_t
num_micro_batches
,
const
framework
::
ProgramDesc
&
program_desc
);
FleetExecutorDesc
exe_desc_
;
FleetExecutorDesc
exe_desc_
;
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
framework
::
Scope
*
root_scope_
;
framework
::
Scope
*
minibatch_scope_
;
platform
::
Place
place_
;
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
;
std
::
unordered_set
<
std
::
string
>
carrier_ids_
;
std
::
unordered_set
<
std
::
string
>
carrier_ids_
;
};
};
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
浏览文件 @
c50c22b0
...
@@ -23,5 +23,4 @@ message RankInfo {
...
@@ -23,5 +23,4 @@ message RankInfo {
message
FleetExecutorDesc
{
message
FleetExecutorDesc
{
optional
int64
cur_rank
=
1
[
default
=
0
];
// Rank id of current processor
optional
int64
cur_rank
=
1
[
default
=
0
];
// Rank id of current processor
repeated
RankInfo
cluster_info
=
2
;
repeated
RankInfo
cluster_info
=
2
;
optional
int64
num_micro_batches
=
3
[
default
=
1
];
}
}
python/paddle/fluid/executor.py
浏览文件 @
c50c22b0
...
@@ -400,6 +400,23 @@ def _is_enable_standalone_executor():
...
@@ -400,6 +400,23 @@ def _is_enable_standalone_executor():
return
flag
return
flag
def
_prepare_fleet_executor
():
from
..distributed.fleet.proto
import
fleet_executor_desc_pb2
trainer_endpoints_str
=
os
.
getenv
(
"PADDLE_TRAINER_ENDPOINTS"
,
""
)
trainer_endpoints
=
trainer_endpoints_str
.
split
(
','
)
fleet_exe_desc
=
fleet_executor_desc_pb2
.
FleetExecutorDesc
()
cur_rank
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
fleet_exe_desc
.
cur_rank
=
cur_rank
nrank
=
len
(
trainer_endpoints
)
for
rank
,
endpoint
in
enumerate
(
trainer_endpoints
):
rank_info
=
fleet_executor_desc_pb2
.
RankInfo
()
rank_info
.
rank
=
rank
rank_info
.
ip_port
=
endpoint
fleet_exe_desc
.
cluster_info
.
append
(
rank_info
)
fleet_exe
=
core
.
FleetExecutor
(
fleet_exe_desc
.
SerializeToString
())
return
fleet_exe
def
_get_strong_program_cache_key
(
program
,
feed
,
fetch_list
):
def
_get_strong_program_cache_key
(
program
,
feed
,
fetch_list
):
# NOTE(xiongkun) id(proram) may be duplicate. So add addition var_name as cache key.
# NOTE(xiongkun) id(proram) may be duplicate. So add addition var_name as cache key.
def
_get_varname_from_block
(
block
):
def
_get_varname_from_block
(
block
):
...
@@ -692,6 +709,8 @@ class Executor(object):
...
@@ -692,6 +709,8 @@ class Executor(object):
self
.
_enable_interpreter_core
=
_is_enable_standalone_executor
()
self
.
_enable_interpreter_core
=
_is_enable_standalone_executor
()
self
.
_executor_cache
=
_ExecutorCache
(
self
.
place
)
self
.
_executor_cache
=
_ExecutorCache
(
self
.
place
)
self
.
_fleet_executor
=
None
def
_get_scope_cache
(
self
,
program_cache_key
):
def
_get_scope_cache
(
self
,
program_cache_key
):
return
self
.
scope_caches
.
get
(
program_cache_key
,
None
)
return
self
.
scope_caches
.
get
(
program_cache_key
,
None
)
...
@@ -1281,6 +1300,9 @@ class Executor(object):
...
@@ -1281,6 +1300,9 @@ class Executor(object):
if
isinstance
(
program
,
Program
)
and
program
.
_pipeline_opt
:
if
isinstance
(
program
,
Program
)
and
program
.
_pipeline_opt
:
if
"fleet_opt"
in
program
.
_pipeline_opt
:
if
"fleet_opt"
in
program
.
_pipeline_opt
:
# Move prepare here for port conflict with nccl in startup program
if
self
.
_fleet_executor
is
None
:
self
.
_fleet_executor
=
_prepare_fleet_executor
()
return
self
.
_run_using_fleet_executor
(
return
self
.
_run_using_fleet_executor
(
program
=
program
,
feed
=
feed
,
fetch_list
=
fetch_list
)
program
=
program
,
feed
=
feed
,
fetch_list
=
fetch_list
)
if
"startup_program"
in
program
.
_pipeline_opt
:
if
"startup_program"
in
program
.
_pipeline_opt
:
...
@@ -1960,27 +1982,16 @@ class Executor(object):
...
@@ -1960,27 +1982,16 @@ class Executor(object):
return
ctx
return
ctx
def
_prepare_fleet_executor
(
self
,
def
_prepare_fleet_executor
_carrier
(
self
,
carrier_id
=
""
,
carrier_id
=
""
,
program
=
None
,
program
=
None
,
scope
=
None
,
scope
=
None
,
fleet_opt
=
None
):
fleet_opt
=
None
):
from
..distributed.fleet.proto
import
fleet_executor_desc_pb2
num_micro_batches
=
fleet_opt
[
assert
program
,
"Program for fleet executor should not be None"
"num_micro_batches"
]
if
"num_micro_batches"
in
fleet_opt
else
1
assert
fleet_opt
,
"Configurations for fleet executor should not be None"
trainer_endpoints_str
=
os
.
getenv
(
"PADDLE_TRAINER_ENDPOINTS"
,
""
)
trainer_endpoints
=
trainer_endpoints_str
.
split
(
','
)
fleet_exe_desc
=
fleet_executor_desc_pb2
.
FleetExecutorDesc
()
cur_rank
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
cur_rank
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
fleet_exe_desc
.
cur_rank
=
cur_rank
trainer_endpoints
=
os
.
getenv
(
"PADDLE_TRAINER_ENDPOINTS"
,
""
).
split
(
','
)
nrank
=
len
(
trainer_endpoints
)
nrank
=
len
(
trainer_endpoints
)
for
rank
,
endpoint
in
enumerate
(
trainer_endpoints
):
rank_info
=
fleet_executor_desc_pb2
.
RankInfo
()
rank_info
.
rank
=
rank
rank_info
.
ip_port
=
endpoint
fleet_exe_desc
.
cluster_info
.
append
(
rank_info
)
if
"num_micro_batches"
in
fleet_opt
:
fleet_exe_desc
.
num_micro_batches
=
fleet_opt
[
"num_micro_batches"
]
assert
'scheduler'
in
fleet_opt
or
'tasks'
in
fleet_opt
,
\
assert
'scheduler'
in
fleet_opt
or
'tasks'
in
fleet_opt
,
\
"Fleet executor need configuration for scheduler, you can choose from 1F1B or Origin. "
\
"Fleet executor need configuration for scheduler, you can choose from 1F1B or Origin. "
\
...
@@ -2019,12 +2030,10 @@ class Executor(object):
...
@@ -2019,12 +2030,10 @@ class Executor(object):
# NOTE: have to hold these vars, otherwise will be destructed
# NOTE: have to hold these vars, otherwise will be destructed
fleet_opt
[
'tasks'
]
=
tasks
fleet_opt
[
'tasks'
]
=
tasks
fleet_opt
[
'task_id_to_rank'
]
=
task_id_to_rank
fleet_opt
[
'task_id_to_rank'
]
=
task_id_to_rank
fleet_exe
=
core
.
FleetExecutor
(
fleet_exe_desc
.
SerializeToString
())
place
=
core
.
Place
()
place
=
core
.
Place
()
place
.
set_place
(
self
.
place
)
place
.
set_place
(
self
.
place
)
fleet_exe
.
init
(
carrier_id
,
program
.
desc
,
scope
,
place
,
tasks
,
self
.
_fleet_executor
.
init
(
carrier_id
,
program
.
desc
,
scope
,
place
,
task_id_to_rank
)
num_micro_batches
,
tasks
,
task_id_to_rank
)
return
fleet_exe
def
_run_using_fleet_executor
(
self
,
def
_run_using_fleet_executor
(
self
,
program
=
None
,
program
=
None
,
...
@@ -2032,16 +2041,15 @@ class Executor(object):
...
@@ -2032,16 +2041,15 @@ class Executor(object):
feed_var_name
=
"feed"
,
feed_var_name
=
"feed"
,
fetch_var_name
=
"fetch"
,
fetch_var_name
=
"fetch"
,
fetch_list
=
None
):
fetch_list
=
None
):
# TODO(liyurui): Change cache strategy for multi carriers
cache_key
=
_get_strong_program_cache_key
(
program
,
feed
,
fetch_list
)
cache_key
=
_get_strong_program_cache_key
(
program
,
feed
,
fetch_list
)
cached_ctx
=
self
.
_get_ctx_cache
(
cache_key
)
cached_scope
=
self
.
_get_scope_cache
(
cache_key
)
cached_program
=
self
.
_get_program_cache
(
cache_key
)
cached_program
=
self
.
_get_program_cache
(
cache_key
)
real_feed
=
[]
if
feed
is
None
else
feed
cached_scope
=
self
.
_get_scope_cache
(
cache_key
)
if
cached_scope
is
None
:
if
cached_scope
is
None
:
cached_scope
=
global_scope
()
cached_scope
=
global_scope
()
self
.
_add_scope_cache
(
cache_key
,
cached_scope
)
self
.
_add_scope_cache
(
cache_key
,
cached_scope
)
if
cached_program
is
None
:
if
cached_program
is
None
:
assert
program
.
_pipeline_opt
,
"program should have _pipeline_opt to start carrier"
real_feed
=
[]
if
feed
is
None
else
feed
real_program
=
program
real_program
=
program
if
"section_program"
in
program
.
_pipeline_opt
:
if
"section_program"
in
program
.
_pipeline_opt
:
real_program
=
program
.
_pipeline_opt
[
"section_program"
]
real_program
=
program
.
_pipeline_opt
[
"section_program"
]
...
@@ -2060,7 +2068,6 @@ class Executor(object):
...
@@ -2060,7 +2068,6 @@ class Executor(object):
'op_role'
,
'op_role'
,
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
)
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
)
self
.
_add_program_cache
(
cache_key
,
cached_program
)
self
.
_add_program_cache
(
cache_key
,
cached_program
)
if
cached_ctx
is
None
:
fleet_opt
=
program
.
_pipeline_opt
[
"fleet_opt"
]
fleet_opt
=
program
.
_pipeline_opt
[
"fleet_opt"
]
if
'tasks'
in
fleet_opt
:
if
'tasks'
in
fleet_opt
:
# Insert feed/fetch op for cloned program in each task node,
# Insert feed/fetch op for cloned program in each task node,
...
@@ -2097,12 +2104,12 @@ class Executor(object):
...
@@ -2097,12 +2104,12 @@ class Executor(object):
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
)
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
)
fetch_task
.
set_program
(
fetch_program
)
fetch_task
.
set_program
(
fetch_program
)
cached_ctx
=
self
.
_prepare_fleet_executo
r
(
self
.
_prepare_fleet_executor_carrie
r
(
cache_key
,
cache_key
,
program
=
cached_program
,
program
=
cached_program
,
scope
=
cached_scope
,
scope
=
cached_scope
,
fleet_opt
=
fleet_opt
)
fleet_opt
=
fleet_opt
)
self
.
_add_ctx_cache
(
cache_key
,
cached_ctx
)
if
feed
:
if
feed
:
# NOTE: don't have to traverse programs in task nodes,
# NOTE: don't have to traverse programs in task nodes,
# since they all sub program of cached program and
# since they all sub program of cached program and
...
@@ -2120,7 +2127,8 @@ class Executor(object):
...
@@ -2120,7 +2127,8 @@ class Executor(object):
lr_sheduler
.
_var_name
)
lr_sheduler
.
_var_name
)
tensor
.
set
(
data
,
self
.
place
)
tensor
.
set
(
data
,
self
.
place
)
cached_ctx
.
run
(
cache_key
)
self
.
_fleet_executor
.
run
(
cache_key
)
if
fetch_list
:
if
fetch_list
:
arr
=
cached_scope
.
find_var
(
fetch_var_name
).
get_fetch_list
()
arr
=
cached_scope
.
find_var
(
fetch_var_name
).
get_fetch_list
()
tensors
=
arr
.
_move_to_list
()
tensors
=
arr
.
_move_to_list
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录