Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a6344af2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a6344af2
编写于
9月 10, 2020
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update, test=develop
上级
f71543ee
变更
6
展开全部
隐藏空白更改
内联
并排
Showing
6 changed file
with
501 addition
and
815 deletion
+501
-815
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+9
-8
paddle/fluid/framework/pipeline_trainer.cc
paddle/fluid/framework/pipeline_trainer.cc
+229
-142
paddle/fluid/framework/section_worker.cc
paddle/fluid/framework/section_worker.cc
+208
-572
paddle/fluid/framework/trainer.h
paddle/fluid/framework/trainer.h
+21
-13
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+1
-1
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+33
-79
未找到文件。
paddle/fluid/framework/device_worker.h
浏览文件 @
a6344af2
...
...
@@ -414,7 +414,8 @@ class HeterCpuWorker : public HogwildWorker {
#if defined(PADDLE_WITH_NCCL)
class
SectionWorker
:
public
DeviceWorker
{
public:
SectionWorker
()
{
local_batch_id_
=
0
;
}
// SectionWorker() { local_batch_id_ = 0; }
SectionWorker
()
{}
~
SectionWorker
()
override
{}
void
Initialize
(
const
TrainerDesc
&
desc
)
override
;
...
...
@@ -429,7 +430,7 @@ class SectionWorker : public DeviceWorker {
const
platform
::
Place
&
place
()
const
{
return
place_
;
}
void
SetSectionIndex
(
int
section_id
)
{
section_id_
=
section_id
;
}
//
void SetSectionIndex(int section_id) { section_id_ = section_id; }
void
SetDeviceIndex
(
int
tid
)
override
{}
void
SetThreadIndex
(
int
thread_id
)
{
thread_id_
=
thread_id
;
}
void
SetMicrobatchNum
(
int
num
)
{
num_microbatches_
=
num
;
}
...
...
@@ -440,7 +441,7 @@ class SectionWorker : public DeviceWorker {
void
SetSkipVars
(
const
std
::
vector
<
std
::
string
>&
skip_vars
)
{
skip_vars_
=
skip_vars
;
}
static
void
ResetBatchId
()
{
batch_id_
=
0
;
}
//
static void ResetBatchId() { batch_id_ = 0; }
static
std
::
atomic
<
int
>
cpu_id_
;
...
...
@@ -454,13 +455,13 @@ class SectionWorker : public DeviceWorker {
const
Scope
*
minibatch_scope_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_
;
static
std
::
mutex
thread_mutex
;
static
std
::
mutex
cout_mutex
;
static
std
::
condition_variable
thread_condition
;
static
bool
threads_completed
;
//
static std::mutex thread_mutex;
//
static std::mutex cout_mutex;
//
static std::condition_variable thread_condition;
//
static bool threads_completed;
std
::
shared_ptr
<
framework
::
ProgramDesc
>
program_
;
static
uint64_t
batch_id_
;
uint64_t
local_batch_id_
;
//
uint64_t local_batch_id_;
platform
::
DeviceContext
*
dev_ctx_
=
nullptr
;
};
...
...
paddle/fluid/framework/pipeline_trainer.cc
浏览文件 @
a6344af2
...
...
@@ -27,73 +27,88 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
const
auto
&
section_params
=
trainer_desc
.
section_param
();
num_microbatches_
=
section_params
.
num_microbatches
();
VLOG
(
3
)
<<
"Number of microbatches per minibatch: "
<<
num_microbatches_
;
section_num_
=
section_params
.
section_config_size
();
VLOG
(
3
)
<<
"Number of program sections: "
<<
section_num_
;
trainer_desc_
=
trainer_desc
;
start_cpu_core_id_
=
section_params
.
start_cpu_core_id
();
SetDataset
(
dataset
);
//
SetDataset(dataset);
ParseDumpConfig
(
trainer_desc
);
// get filelist from trainer_desc here
const
std
::
vector
<
paddle
::
framework
::
DataFeed
*>
readers
=
dataset
->
GetReaders
();
VLOG
(
3
)
<<
"readers num: "
<<
readers
.
size
();
int
num_readers
=
readers
.
size
();
PADDLE_ENFORCE_EQ
(
num_readers
,
1
,
platform
::
errors
::
InvalidArgument
(
"Number of dataset readers for pipeline "
"must be 1 now, but the value you give is %d."
,
num_readers
));
auto
*
reader
=
readers
[
0
];
// const std::vector<paddle::framework::DataFeed*> readers =
// VLOG(3) << "Number of program sections: " << section_num_;
// dataset->GetReaders();
// VLOG(3) << "readers num: " << readers.size();
// int num_readers = readers.size();
// PADDLE_ENFORCE_EQ(num_readers, 1,
// platform::errors::InvalidArgument(
// "Number of dataset readers for pipeline "
// "must be 1 now, but the value you give is %d.",
// num_readers));
// auto* reader = readers[0];
workers_
.
resize
(
section_num_
);
for
(
int
i
=
0
;
i
<
section_num_
;
++
i
)
{
const
auto
&
section_config
=
section_params
.
section_config
(
i
);
platform
::
Place
place
;
int
place_id
=
section_config
.
place_id
();
switch
(
section_config
.
place
())
{
case
SectionConfig
::
CPUPlace
:
place
=
platform
::
CPUPlace
();
break
;
case
SectionConfig
::
CUDAPlace
:
// Note that one section has at most one GPU place in one pipeline
PADDLE_ENFORCE_GE
(
place_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The place_id value for CUDAPlace shoud be greater "
"than or equal to 0, but the value you give is %d."
,
place_id
));
place
=
platform
::
CUDAPlace
(
place_id
);
break
;
case
SectionConfig
::
CUDAPinnedPlace
:
place
=
platform
::
CUDAPinnedPlace
();
break
;
default:
PADDLE_ENFORCE_NOT_NULL
(
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Unkown place type in SectionConfig: %d"
,
section_config
.
place
()));
}
places_
.
emplace_back
(
place
);
VLOG
(
3
)
<<
"Device worker place: "
<<
place
<<
", device id: "
<<
place_id
<<
", section: "
<<
i
;
// workers_.resize(section_num_);
// for (int i = 0; i < section_num_; ++i) {
// const auto& section_config = section_params.section_config(i);
// platform::Place place;
// int place_id = section_config.place_id();
// switch (section_config.place()) {
// case SectionConfig::CPUPlace:
// place = platform::CPUPlace();
// break;
// case SectionConfig::CUDAPlace:
// // Note that one section has at most one GPU place in one pipeline
// PADDLE_ENFORCE_GE(
// place_id, 0,
// platform::errors::InvalidArgument(
// "The place_id value for CUDAPlace shoud be greater "
// "than or equal to 0, but the value you give is %d.",
// place_id));
// place = platform::CUDAPlace(place_id);
// break;
// case SectionConfig::CUDAPinnedPlace:
// place = platform::CUDAPinnedPlace();
// break;
// default:
// PADDLE_ENFORCE_NOT_NULL(nullptr,
// platform::errors::InvalidArgument(
// "Unkown place type in SectionConfig: %d",
// section_config.place()));
// }
// places_.emplace_back(place);
// VLOG(3) << "Device worker place: " << place << ", device id: " << place_id
// << ", section: " << i;
// workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
// trainer_desc.device_worker_name());
// auto this_worker =
// std::dynamic_pointer_cast<paddle::framework::SectionWorker>(
// workers_[i]);
// if (i == 0) {
// // we only set reader for the first section
// this_worker->SetDataFeed(reader);
// this_worker->SetReaderPlace(place);
// }
// this_worker->SetThreadIndex(i);
// this_worker->SetSectionIndex(i);
// this_worker->SetPlace(place);
// this_worker->Initialize(trainer_desc);
// this_worker->SetMicrobatchNum(num_microbatches_);
//}
const
auto
&
section_config
=
section_params
.
section_config
();
int
place_id
=
section_config
.
place_id
();
PADDLE_ENFORCE_GE
(
place_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The place_id value for CUDAPlace shoud be "
"non-negative, but the value given is %d."
,
place_id
));
place_
=
platform
::
CUDAPlace
(
place_id
);
worker_
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
auto
this_worker
=
std
::
dynamic_pointer_cast
<
paddle
::
framework
::
SectionWorker
>
(
worker_
);
this_worker
->
SetPlace
(
place_
);
this_worker
->
Initialize
(
trainer_desc
);
this_worker
->
SetMicrobatchNum
(
num_microbatches_
);
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
auto
this_worker
=
std
::
dynamic_pointer_cast
<
paddle
::
framework
::
SectionWorker
>
(
workers_
[
i
]);
if
(
i
==
0
)
{
// we only set reader for the first section
this_worker
->
SetDataFeed
(
reader
);
this_worker
->
SetReaderPlace
(
place
);
}
this_worker
->
SetThreadIndex
(
i
);
this_worker
->
SetSectionIndex
(
i
);
this_worker
->
SetPlace
(
place
);
this_worker
->
Initialize
(
trainer_desc
);
this_worker
->
SetMicrobatchNum
(
num_microbatches_
);
}
// set debug here
SetDebug
(
trainer_desc
.
debug
());
}
...
...
@@ -119,7 +134,52 @@ void PipelineTrainer::InitDumpEnv() {
}
}
void
PipelineTrainer
::
CopyParameters
(
int
section_id
,
int
microbatch_id
,
// void PipelineTrainer::CopyParameters(int section_id, int microbatch_id,
// const ProgramDesc& program,
// const platform::Place& place) {
// auto& global_block = program.Block(0);
// std::map<std::string, int> param_map;
// for (auto& var : global_block.AllVars()) {
// if (var->Persistable()) {
// param_map[var->Name()] = 1;
// }
// }
// for (auto& var : global_block.AllVars()) {
// bool is_param_grad = false;
// size_t pos = 0;
// if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) {
// auto prefix_name = var->Name().substr(0, pos);
// if (param_map.find(prefix_name) != param_map.end()) {
// is_param_grad = true;
// }
// }
// VLOG(3) << "Var name: " << var->Name();
// if ((var->Persistable() || is_param_grad) && microbatch_id == 0) {
// auto* ptr = root_scope_->FindVar(var->Name());
// auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
// VLOG(3) << "Create persistable var " << var->Name() << " for minibatch
// "
// << section_id << ", which pointer is " << new_ptr;
// InitializeVariable(new_ptr, var->GetType());
// if (is_param_grad) {
// continue;
// }
// const LoDTensor& root_tensor = ptr->Get<LoDTensor>();
// LoDTensor* minibatch_tensor = new_ptr->GetMutable<LoDTensor>();
// TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
// static_cast<Tensor*>(minibatch_tensor));
// } else if (!var->Persistable() && !is_param_grad) {
// auto* ptr =
// microbatch_scopes_[section_id][microbatch_id]->Var(var->Name());
// VLOG(3) << "Create variable " << var->Name() << " for section "
// << section_id << " microbatch " << microbatch_id
// << ", which pointer is " << ptr;
// InitializeVariable(ptr, var->GetType());
// }
// }
// }
void
PipelineTrainer
::
CopyParameters
(
int
microbatch_id
,
const
ProgramDesc
&
program
,
const
platform
::
Place
&
place
)
{
auto
&
global_block
=
program
.
Block
(
0
);
...
...
@@ -139,45 +199,57 @@ void PipelineTrainer::CopyParameters(int section_id, int microbatch_id,
}
}
VLOG
(
3
)
<<
"Var name: "
<<
var
->
Name
();
if
((
var
->
Persistable
()
||
is_param_grad
)
&&
microbatch_id
==
0
)
{
auto
*
ptr
=
root_scope_
->
FindVar
(
var
->
Name
());
auto
*
new_ptr
=
minibatch_scopes_
[
section_id
]
->
Var
(
var
->
Name
());
VLOG
(
3
)
<<
"Create persistable var "
<<
var
->
Name
()
<<
" for minibatch "
<<
section_id
<<
", which pointer is "
<<
new_ptr
;
InitializeVariable
(
new_ptr
,
var
->
GetType
());
if
(
is_param_grad
)
{
continue
;
}
const
LoDTensor
&
root_tensor
=
ptr
->
Get
<
LoDTensor
>
();
LoDTensor
*
minibatch_tensor
=
new_ptr
->
GetMutable
<
LoDTensor
>
();
TensorCopy
(
*
static_cast
<
const
Tensor
*>
(
&
root_tensor
),
place
,
static_cast
<
Tensor
*>
(
minibatch_tensor
));
if
(
is_param_grad
&&
microbatch_id
==
0
)
{
auto
*
ptr
=
minibatch_scope_
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create grad for persistable var: "
<<
var
->
Name
()
<<
", which pointer is "
<<
ptr
;
}
else
if
(
!
var
->
Persistable
()
&&
!
is_param_grad
)
{
auto
*
ptr
=
microbatch_scopes_
[
section_id
][
microbatch_id
]
->
Var
(
var
->
Name
());
VLOG
(
3
)
<<
"Create variable "
<<
var
->
Name
()
<<
" for section "
<<
section_id
<<
" microbatch "
<<
microbatch_id
auto
*
ptr
=
microbatch_scopes_
[
microbatch_id
]
->
Var
(
var
->
Name
());
VLOG
(
3
)
<<
"Create variable "
<<
var
->
Name
()
<<
" microbatch "
<<
", which pointer is "
<<
ptr
;
InitializeVariable
(
ptr
,
var
->
GetType
());
}
}
}
void
PipelineTrainer
::
GetSkipVars
(
int
section_id
,
const
ProgramDesc
&
program
)
{
// void PipelineTrainer::GetSkipVars(int section_id, const ProgramDesc& program)
// {
// auto& global_block = program.Block(0);
// for (auto& op : global_block.AllOps()) {
// if (op->Type() != "enqueue") {
// continue;
// }
// auto input_arg_names = op->InputArgumentNames();
// PADDLE_ENFORCE_EQ(input_arg_names.size(), 1,
// platform::errors::InvalidArgument(
// "Number of input arguments for enqueue op must be
// 1, "
// "but the value is %d.",
// input_arg_names.size()));
// std::string input_arg_name = input_arg_names[0];
// if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) {
// skip_vars_[section_id].emplace_back(input_arg_name);
// VLOG(3) << "add skip var name: " << input_arg_name;
// }
// }
// }
void
PipelineTrainer
::
GetSkipVars
(
const
ProgramDesc
&
program
)
{
auto
&
global_block
=
program
.
Block
(
0
);
for
(
auto
&
op
:
global_block
.
AllOps
())
{
if
(
op
->
Type
()
!=
"
enqueue
"
)
{
if
(
op
->
Type
()
!=
"
c_send
"
)
{
continue
;
}
auto
input_arg_names
=
op
->
InputArgumentNames
();
PADDLE_ENFORCE_EQ
(
input_arg_names
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Number of input arguments for
enqueue
op must be 1, "
"but the value is %d."
,
"Number of input arguments for
c_send
op must be 1, "
"but the value
given
is %d."
,
input_arg_names
.
size
()));
std
::
string
input_arg_name
=
input_arg_names
[
0
];
if
(
input_arg_name
.
rfind
(
"@GRAD"
)
!=
input_arg_name
.
size
()
-
5
)
{
skip_vars_
[
section_id
]
.
emplace_back
(
input_arg_name
);
skip_vars_
.
emplace_back
(
input_arg_name
);
VLOG
(
3
)
<<
"add skip var name: "
<<
input_arg_name
;
}
}
...
...
@@ -185,86 +257,101 @@ void PipelineTrainer::GetSkipVars(int section_id, const ProgramDesc& program) {
void
PipelineTrainer
::
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
const
platform
::
Place
&
place
)
{
PADDLE_ENFORCE_NOT_NULL
(
root_scope_
,
platform
::
errors
::
InvalidArgument
(
"root_scope pointer can not be nullptr"
));
PADDLE_ENFORCE_NOT_NULL
(
root_scope_
,
platform
::
errors
::
InvalidArgument
(
"root_scope_ can not be nullptr"
));
auto
start_cpu_id
=
trainer_desc_
.
section_param
().
start_cpu_core_id
();
SectionWorker
::
cpu_id_
.
store
(
start_cpu_id
);
minibatch_scopes_
.
resize
(
section_num_
);
microbatch_scopes_
.
resize
(
section_num_
);
skip_vars_
.
resize
(
section_num_
);
// minibatch_scopes_.resize(section_num_);
// microbatch_scopes_.resize(section_num_);
// minibatch_scopes_.resize(1);
microbatch_scopes_
.
resize
(
num_microbatches_
);
// skip_vars_.resize(section_num_);
VLOG
(
3
)
<<
"Init ScopeQueues and create all scopes"
;
for
(
int
i
=
0
;
i
<
section_num_
;
++
i
)
{
minibatch_scopes_
[
i
]
=
&
root_scope_
->
NewScope
();
std
::
shared_ptr
<
framework
::
ProgramDesc
>
program
;
program
.
reset
(
new
ProgramDesc
(
trainer_desc_
.
section_param
().
section_config
(
i
).
program_desc
()));
microbatch_scopes_
[
i
].
resize
(
num_microbatches_
);
for
(
int
j
=
0
;
j
<
num_microbatches_
;
++
j
)
{
microbatch_scopes_
[
i
][
j
]
=
&
minibatch_scopes_
[
i
]
->
NewScope
();
CopyParameters
(
i
,
j
,
*
program
,
places_
[
i
]);
}
GetSkipVars
(
i
,
*
program
);
// for (int i = 0; i < section_num_; ++i) {
minibatch_scope_
=
&
root_scope_
->
NewScope
();
std
::
shared_ptr
<
framework
::
ProgramDesc
>
program
;
program
.
reset
(
new
ProgramDesc
(
trainer_desc_
.
section_param
().
section_config
().
program_desc
()));
// trainer_desc_.section_param().section_config(i).program_desc()));
// microbatch_scopes_[i].resize(num_microbatches_);
for
(
int
j
=
0
;
j
<
num_microbatches_
;
++
j
)
{
// microbatch_scopes_[j] = &minibatch_scopes_[i]->NewScope();
microbatch_scopes_
[
j
]
=
&
minibatch_scope_
->
NewScope
();
// CopyParameters(i, j, *program, places_[i]);
CopyParameters
(
j
,
*
program
,
place_
);
}
// GetSkipVars(i, *program);
GetSkipVars
(
*
program
);
// }
for
(
int
i
=
0
;
i
<
section_num_
;
++
i
)
{
auto
this_worker
=
std
::
dynamic_pointer_cast
<
paddle
::
framework
::
SectionWorker
>
(
workers_
[
i
]);
this_worker
->
SetRootScope
(
root_scope_
);
this_worker
->
SetMinibatchScope
(
minibatch_scopes_
[
i
]);
this_worker
->
SetMicrobatchScopes
(
microbatch_scopes_
[
i
]);
this_worker
->
SetSkipVars
(
skip_vars_
[
i
]);
}
// for (int i = 0; i < section_num_; ++i) {
auto
this_worker
=
std
::
dynamic_pointer_cast
<
paddle
::
framework
::
SectionWorker
>
(
worker_
);
// workers_[i]);
this_worker
->
SetRootScope
(
root_scope_
);
this_worker
->
SetMinibatchScope
(
minibatch_scope_
);
// this_worker->SetMicrobatchScopes(microbatch_scopes_[i]);
this_worker
->
SetMicrobatchScopes
(
microbatch_scopes_
);
// this_worker->SetSkipVars(skip_vars_[i]);
//}
}
void
PipelineTrainer
::
Run
()
{
VLOG
(
3
)
<<
"Going to run"
;
for
(
int
i
=
0
;
i
<
section_num_
;
++
i
)
{
if
(
!
debug_
)
{
section_threads_
.
push_back
(
std
::
thread
(
&
DeviceWorker
::
TrainFiles
,
workers_
[
i
].
get
()));
}
else
{
section_threads_
.
push_back
(
std
::
thread
(
&
DeviceWorker
::
TrainFilesWithProfiler
,
workers_
[
i
].
get
()));
}
// for (int i = 0; i < section_num_; ++i) {
if
(
!
debug_
)
{
section_thread_
=
std
::
thread
(
&
DeviceWorker
::
TrainFiles
,
worker_
.
get
());
// section_threads_.push_back(
// std::thread(&DeviceWorker::TrainFiles, workers_.get()));
// std::thread(&DeviceWorker::TrainFiles, workers_[i].get()));
}
else
{
section_thread_
=
std
::
thread
(
&
DeviceWorker
::
TrainFilesWithProfiler
,
worker_
.
get
());
// section_threads_.push_back(std::thread(
// &DeviceWorker::TrainFilesWithProfiler, workers_.get()));
// &DeviceWorker::TrainFilesWithProfiler, workers_[i].get()));
}
//}
}
void
PipelineTrainer
::
Finalize
()
{
for
(
auto
&
th
:
section_threads_
)
{
th
.
join
();
}
// for (auto& th : section_threads_) {
// th.join();
//}
section_thread_
.
join
();
if
(
need_dump_field_
)
{
FinalizeDumpEnv
();
}
VLOG
(
3
)
<<
"copying back parameters. "
;
for
(
int
i
=
0
;
i
<
section_num_
;
++
i
)
{
std
::
shared_ptr
<
framework
::
ProgramDesc
>
program
;
program
.
reset
(
new
ProgramDesc
(
trainer_desc_
.
section_param
().
section_config
(
i
).
program_desc
()));
for
(
int
j
=
0
;
j
<
num_microbatches_
;
++
j
)
{
auto
&
global_block
=
program
->
Block
(
0
);
for
(
auto
&
var
:
global_block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
root_scope_
->
FindVar
(
var
->
Name
());
LoDTensor
*
root_tensor
=
ptr
->
GetMutable
<
LoDTensor
>
();
auto
*
minibatch_ptr
=
minibatch_scopes_
[
i
]
->
Var
(
var
->
Name
());
const
LoDTensor
&
minibatch_tensor
=
minibatch_ptr
->
Get
<
LoDTensor
>
();
TensorCopy
(
*
static_cast
<
const
Tensor
*>
(
&
minibatch_tensor
),
places_
[
0
],
static_cast
<
Tensor
*>
(
root_tensor
));
VLOG
(
3
)
<<
"Copy persitable var "
<<
var
->
Name
()
<<
" to root scope"
;
}
}
}
}
// VLOG(3) << "copying back parameters. ";
// for (int i = 0; i < section_num_; ++i) {
// std::shared_ptr<framework::ProgramDesc> program;
// program.reset(new ProgramDesc(
// trainer_desc_.section_param().section_config(i).program_desc()));
// for (int j = 0; j < num_microbatches_; ++j) {
// auto& global_block = program->Block(0);
// for (auto& var : global_block.AllVars()) {
// if (var->Persistable()) {
// auto* ptr = root_scope_->FindVar(var->Name());
// LoDTensor* root_tensor = ptr->GetMutable<LoDTensor>();
// auto* minibatch_ptr = minibatch_scopes_[i]->Var(var->Name());
// const LoDTensor& minibatch_tensor =
// minibatch_ptr->Get<LoDTensor>();
// TensorCopy(*static_cast<const Tensor*>(&minibatch_tensor),
// places_[0],
// static_cast<Tensor*>(root_tensor));
// VLOG(3) << "Copy persitable var " << var->Name() << " to root
// scope";
// }
// }
// }
// }
root_scope_
->
DropKids
();
SectionWorker
::
ResetBatchId
();
//
SectionWorker::ResetBatchId();
}
Scope
*
PipelineTrainer
::
GetWorkerScope
(
int
thread_id
)
{
return
microbatch_scopes_
[
thread_id
][
0
];
return
microbatch_scopes_
[
0
];
}
}
// end namespace framework
...
...
paddle/fluid/framework/section_worker.cc
浏览文件 @
a6344af2
此差异已折叠。
点击以展开。
paddle/fluid/framework/trainer.h
浏览文件 @
a6344af2
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
...
...
@@ -217,28 +218,35 @@ class PipelineTrainer : public TrainerBase {
virtual
Scope
*
GetWorkerScope
(
int
thread_id
);
void
InitDumpEnv
()
override
;
virtual
std
::
string
GetDumpPath
(
int
tid
);
void
GetSkipVars
(
int
section_id
,
const
ProgramDesc
&
main_program
);
void
GetSkipVars
(
const
ProgramDesc
&
main_program
);
protected:
int
section_num_
;
//
int section_num_;
int
num_microbatches_
;
int
start_cpu_core_id_
;
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
skip_vars_
;
// std::vector<platform::Place> places_;
platform
::
Place
place_
;
// std::vector<std::vector<std::string>> skip_vars_;
std
::
vector
<
std
::
string
>
skip_vars_
;
TrainerDesc
trainer_desc_
;
std
::
vector
<
std
::
thread
>
section_threads_
;
// std::vector<std::thread> section_threads_;
std
::
thread
section_thread_
;
// worker: [section_id]
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DeviceWorker
>>
workers_
;
// std::vector<std::shared_ptr<paddle::framework::DeviceWorker>> workers_;
std
::
shared_ptr
<
paddle
::
framework
::
DeviceWorker
>
worker_
;
// minibatch_scopes_: [section_id]
std
::
vector
<
Scope
*>
minibatch_scopes_
;
// std::vector<Scope*> minibatch_scopes_;
Scope
*
minibatch_scope_
;
// microbatch_scopes_: [section_id][microbatch_id]
std
::
vector
<
std
::
vector
<
Scope
*>>
microbatch_scopes_
;
void
CopyParameters
(
int
section_id
,
int
microbatch_id
,
const
ProgramDesc
&
program
,
const
platform
::
Place
&
place
);
bool
isPersistableVarGrad
(
std
::
string
name
);
bool
isPersistable
(
VarDesc
*
var
);
// std::vector<std::vector<Scope*>> microbatch_scopes_;
// microbatch_scopes_: [microbatch_id]
std
::
vector
<
Scope
*>
microbatch_scopes_
;
void
CopyParameters
(
int
microbatch_id
,
const
ProgramDesc
&
program
,
const
platform
::
Place
&
place
);
// bool isPersistableVarGrad(std::string name);
// bool isPersistable(VarDesc* var);
};
#endif
...
...
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
a6344af2
...
...
@@ -84,7 +84,7 @@ message DownpourWorkerParameter {
}
message
SectionWorkerParameter
{
repeated
SectionConfig
section_config
=
1
;
SectionConfig
section_config
=
1
;
optional
int32
queue_size
=
2
[
default
=
1
];
optional
int64
sync_steps
=
3
[
default
=
1
];
optional
int32
start_cpu_core_id
=
4
[
default
=
1
];
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
a6344af2
...
...
@@ -3784,6 +3784,7 @@ class PipelineOptimizer(object):
Args:
main_program (Program): the main program
devices: all used devices
"""
programs
=
[]
# Map from device to its corresponding section program info
...
...
@@ -3910,10 +3911,10 @@ class PipelineOptimizer(object):
data_devices_map
[
var_name
].
append
(
dev_spec
)
return
data_devices_map
def
_insert_
enq_deq
_for_data_var
(
self
,
main_block
,
programs
,
startup
,
devices
):
def
_insert_
sendrecv
_for_data_var
(
self
,
main_block
,
programs
,
startup
,
devices
):
"""
Insert
enqueue and dequeue
ops for data var that on other devices.
Insert
send and recv
ops for data var that on other devices.
Args:
main_block (Block): Global block for main program
...
...
@@ -3926,39 +3927,24 @@ class PipelineOptimizer(object):
first_prog
=
programs
[
0
][
'program'
]
first_block
=
first_prog
.
block
(
0
)
enqueue
_index
=
0
insert
_index
=
0
for
op
in
first_block
.
ops
:
enqueue
_index
+=
1
insert
_index
+=
1
if
op
.
type
==
"read"
:
break
first_dev_spec
=
devices
[
0
]
for
var_name
in
data_devices_map
.
keys
():
for
device
in
data_devices_map
[
var_name
]:
if
device
==
first_dev_spec
:
continue
# step1: generate queue for each pair of data var and device
# that that data on
queue_name
=
var_name
+
"_blocking_queue"
queue_name
=
unique_name
.
generate
(
queue_name
)
queue_var
=
startup
.
block
(
0
).
create_var
(
name
=
queue_name
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
startup
.
block
(
0
).
append_op
(
type
=
'queue_generator'
,
attrs
=
{
'names'
:
[
queue_name
],
'capacity'
:
self
.
_num_microbatches
})
main_var
=
main_block
.
var
(
var_name
)
assert
main_var
.
is_data
if
not
var_name
in
first_block
.
vars
:
self
.
_create_var
(
first_block
,
main_var
,
var_name
)
first_block
.
_insert_op
(
index
=
enqueue
_index
,
type
=
'
enqueue
'
,
index
=
insert
_index
,
type
=
'
c_send
'
,
inputs
=
{
'X'
:
first_block
.
var
(
var_name
)},
attrs
=
{
'queue_name'
:
queue_name
,
self
.
_op_device_key
:
first_dev_spec
,
self
.
_op_role_key
:
self
.
_op_role
.
Forward
})
...
...
@@ -3972,12 +3958,11 @@ class PipelineOptimizer(object):
new_var
=
self
.
_create_var
(
block
,
source_var
,
var_name
)
block
.
_insert_op
(
index
=
0
,
type
=
'
dequeue
'
,
type
=
'
c_recv
'
,
outputs
=
{
'Out'
:
[
new_var
]},
attrs
=
{
self
.
_op_device_key
:
device
,
self
.
_op_role_key
:
self
.
_op_role
.
Forward
,
'queue_name'
:
queue_name
,
})
def
_strip_grad_suffix
(
self
,
name
):
...
...
@@ -4080,23 +4065,22 @@ class PipelineOptimizer(object):
assert
sorted_device_specs
==
device_specs
return
device_specs
def
_insert_enq_deq_ops_for_boundaries
(
self
,
block
,
origin_block
,
startup_program
):
def
_insert_sendrecv_ops_for_boundaries
(
self
,
block
,
origin_block
):
"""
Insert a pair of
enqueue and dequeue
ops for every two
Insert a pair of
send and recv
ops for every two
consecutive ops on different devices.
"""
startup_block
=
startup_program
.
global_block
()
extra_index
=
0
# A map from var to device spec where op takes it as input,
# avoiding multiple
enqueue and dequeue
ops.
# avoiding multiple
send and recv
ops.
var_devspec
=
dict
()
for
index
,
op
in
list
(
enumerate
(
origin_block
.
ops
)):
# skips lr-related op and vars, as we will process them later.
# skips lr-related op
s
and vars, as we will process them later.
if
int
(
op
.
attr
(
self
.
_op_role_key
))
&
int
(
self
.
_op_role
.
LRSched
):
continue
# skips update ops and vars, as we will process them later.
if
self
.
_is_update_op
(
op
):
continue
cur_device_spec
=
op
.
attr
(
self
.
_op_device_key
)
...
...
@@ -4119,37 +4103,23 @@ class PipelineOptimizer(object):
if
cur_device_spec
in
var_devspec
[
var_name
]:
continue
var_devspec
[
var_name
].
append
(
cur_device_spec
)
queue_name
=
var_name
+
"_blocking_queue"
queue_name
=
unique_name
.
generate
(
queue_name
)
queue_var
=
startup_block
.
create_var
(
name
=
queue_name
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
startup_block
.
append_op
(
type
=
'queue_generator'
,
attrs
=
{
'names'
:
[
queue_name
],
'capacity'
:
self
.
_num_microbatches
})
op_role
=
op
.
all_attrs
()[
self
.
_op_role_key
]
var
=
block
.
vars
[
var_name
]
block
.
_insert_op
(
index
=
index
+
extra_index
,
type
=
'
enqueue
'
,
type
=
'
c_send
'
,
inputs
=
{
'X'
:
var
},
attrs
=
{
'queue_name'
:
queue_name
,
self
.
_op_device_key
:
prev_device_spec
,
self
.
_op_role_key
:
op_role
})
extra_index
+=
1
block
.
_insert_op
(
index
=
index
+
extra_index
,
type
=
'
dequeue
'
,
type
=
'
c_recv
'
,
outputs
=
{
'Out'
:
[
var
]},
attrs
=
{
self
.
_op_device_key
:
cur_device_spec
,
'queue_name'
:
queue_name
,
self
.
_op_role_key
:
op_role
})
extra_index
+=
1
...
...
@@ -4178,7 +4148,9 @@ class PipelineOptimizer(object):
def
_accumulate_gradients
(
self
,
block
):
"""
Accumulate the graident generated in microbatch to the one in mini-batch.
Accumulate the gradients generated in microbatch to the one in mini-batch.
We also scale the loss corresponding to number of micro-batches at
the same time.
"""
for
index
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
offset
=
index
...
...
@@ -4210,12 +4182,10 @@ class PipelineOptimizer(object):
for
i
in
range
(
0
,
len
(
op_role_var
),
2
):
grad_name
=
op_role_var
[
i
+
1
]
grad_var
=
block
.
vars
[
grad_name
]
param_name
=
op_role_var
[
i
]
param_var
=
block
.
vars
[
param_name
]
new_var_name
=
unique_name
.
generate
(
param_name
)
new_var_name
=
self
.
_append_grad_suffix
(
new_var_name
)
new_var
=
self
.
_create_var
(
block
,
grad_var
,
new_var_name
)
self
.
_rename_arg
(
op
,
grad_name
,
new_var_name
)
new_grad_var_name
=
unique_name
.
generate
(
grad_name
)
new_var
=
self
.
_create_var
(
block
,
grad_var
,
new_grad_var_name
)
self
.
_rename_arg
(
op
,
grad_name
,
new_grad_var_name
)
block
.
_insert_op
(
index
=
offset
+
1
,
type
=
'sum'
,
...
...
@@ -4247,7 +4217,6 @@ class PipelineOptimizer(object):
def
_get_device_info
(
self
,
block
):
for
op
in
block
.
ops
:
if
not
op
.
_has_kernel
(
op
.
type
):
continue
op_device
=
op
.
attr
(
self
.
_op_device_key
)
return
op_device
...
...
@@ -4282,7 +4251,7 @@ class PipelineOptimizer(object):
for
prog
in
var_info
[
var_name
]:
block
=
prog
.
block
(
0
)
for
op
in
block
.
ops
:
if
op
.
type
==
"
dequeue
"
:
continue
if
op
.
type
==
"
c_recv
"
:
continue
# We have processed lr related vars
if
op
.
attr
(
self
.
_op_role_key
)
==
int
(
self
.
_op_role
.
Optimize
.
LRSched
):
...
...
@@ -4306,24 +4275,11 @@ class PipelineOptimizer(object):
for
prog
in
all_progs
:
if
prog
==
write_prog
:
continue
queue_name
=
var_name
+
"_blocking_queue"
queue_name
=
unique_name
.
generate
(
queue_name
)
queue_var
=
startup_prog
.
block
(
0
).
create_var
(
name
=
queue_name
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
startup_prog
.
block
(
0
).
append_op
(
type
=
'queue_generator'
,
attrs
=
{
'names'
:
[
queue_name
],
'capacity'
:
self
.
_num_microbatches
})
write_block
.
_insert_op
(
index
=
0
,
type
=
'
enqueue
'
,
type
=
'
c_send
'
,
inputs
=
{
'X'
:
write_block
.
var
(
var_name
),
},
attrs
=
{
'queue_name'
:
queue_name
,
self
.
_op_device_key
:
write_device
,
# A trick to make the role LRSched to avoid copy every
# microbatch
...
...
@@ -4333,14 +4289,13 @@ class PipelineOptimizer(object):
read_device
=
self
.
_get_device_info
(
read_block
)
read_block
.
_insert_op
(
index
=
0
,
type
=
'
dequeue
'
,
type
=
'
c_recv
'
,
outputs
=
{
'Out'
:
[
read_block
.
var
(
var_name
)]},
attrs
=
{
self
.
_op_device_key
:
read_device
,
# A trick to make the role LRSched to avoid copy every
# microbatch
self
.
_op_role_key
:
self
.
_op_role
.
LRSched
,
'queue_name'
:
queue_name
,
})
def
minimize
(
self
,
...
...
@@ -4365,14 +4320,13 @@ class PipelineOptimizer(object):
device_specs
=
self
.
_check_validation
(
main_block
)
# Step3: add
enqueue and dequeue
ops between section boundaries
# Step3: add
send and recv
ops between section boundaries
origin_prog
=
main_block
.
program
.
clone
(
for_test
=
False
)
origin_main_block
=
origin_prog
.
global_block
()
self
.
_insert_enq_deq_ops_for_boundaries
(
main_block
,
origin_main_block
,
startup_program
)
self
.
_insert_sendrecv_ops_for_boundaries
(
main_block
,
origin_main_block
)
# Step4:
accumulate gradients during backward
# a
nd clear them after update
# Step4:
clear gradients before each mini-batch and
# a
ccumulate gradients during backward
self
.
_clear_gradients
(
main_block
)
self
.
_accumulate_gradients
(
main_block
)
...
...
@@ -4392,14 +4346,14 @@ class PipelineOptimizer(object):
raise
ValueError
(
"Unknown device type: %s"
,
dev_spec
)
# Step5: split program into sections and add pairs of
#
enqueue and dequeue
ops for data var.
#
send and recv
ops for data var.
if
len
(
place_list
)
<=
1
:
raise
ValueError
(
"Run on one device, do not use pipeline."
)
program_list
=
self
.
_split_program
(
main_program
,
device_specs
)
for
p
in
program_list
:
self
.
_create_vars
(
p
[
"program"
].
block
(
0
),
main_program
)
self
.
_insert_
enq_deq
_for_data_var
(
main_block
,
program_list
,
startup_program
,
device_specs
)
self
.
_insert_
sendrecv
_for_data_var
(
main_block
,
program_list
,
startup_program
,
device_specs
)
# Step6: Special Case: process persistable vars that exist in
# multiple sections
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录