Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleRec
提交
c1c5c20d
P
PaddleRec
项目概览
PaddlePaddle
/
PaddleRec
通知
68
Star
12
Fork
5
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
27
列表
看板
标记
里程碑
合并请求
10
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
27
Issue
27
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c1c5c20d
编写于
8月 01, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
for runnable trainer
上级
29aec8e0
变更
19
展开全部
隐藏空白更改
内联
并排
Showing
19 changed file
with
15379 addition
and
24 deletion
+15379
-24
BCLOUD
BCLOUD
+3
-1
paddle/fluid/train/custom_trainer/feed/accessor/accessor.h
paddle/fluid/train/custom_trainer/feed/accessor/accessor.h
+19
-0
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+58
-0
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
...fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+50
-0
paddle/fluid/train/custom_trainer/feed/accessor/me
paddle/fluid/train/custom_trainer/feed/accessor/me
+14840
-0
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
...d/train/custom_trainer/feed/common/runtime_environment.cc
+34
-0
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
...id/train/custom_trainer/feed/common/runtime_environment.h
+31
-7
paddle/fluid/train/custom_trainer/feed/executor/executor.h
paddle/fluid/train/custom_trainer/feed/executor/executor.h
+14
-11
paddle/fluid/train/custom_trainer/feed/main.cc
paddle/fluid/train/custom_trainer/feed/main.cc
+3
-3
paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.h
paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.h
+38
-0
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
+46
-0
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
...uid/train/custom_trainer/feed/process/init_env_process.cc
+28
-0
paddle/fluid/train/custom_trainer/feed/process/init_env_process.h
...luid/train/custom_trainer/feed/process/init_env_process.h
+1
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+145
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
...fluid/train/custom_trainer/feed/process/learner_process.h
+40
-0
paddle/fluid/train/custom_trainer/feed/process/process.cc
paddle/fluid/train/custom_trainer/feed/process/process.cc
+2
-0
paddle/fluid/train/custom_trainer/feed/process/process.h
paddle/fluid/train/custom_trainer/feed/process/process.h
+6
-1
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+20
-0
publish_include.sh
publish_include.sh
+1
-1
未找到文件。
BCLOUD
浏览文件 @
c1c5c20d
...
...
@@ -70,10 +70,12 @@ SharedLibrary('paddle_fluid_avx_mklml', Sources(paddle_fluid_avx_mklml_src, Cpp
#feed
HEADERS('paddle/fluid/train/custom_trainer/feed/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/')
HEADERS('paddle/fluid/train/custom_trainer/feed/common/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/common/')
HEADERS('paddle/fluid/train/custom_trainer/feed/executor/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/executor/')
HEADERS('paddle/fluid/train/custom_trainer/feed/monitor/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/monitor/')
HEADERS('paddle/fluid/train/custom_trainer/feed/dataset/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/dataset/')
HEADERS('paddle/fluid/train/custom_trainer/feed/process/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/process/')
HEADERS('paddle/fluid/train/custom_trainer/feed/shuffler/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/shuffler/')
HEADERS('paddle/fluid/train/custom_trainer/feed/
params_accessor/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/params_
accessor/')
HEADERS('paddle/fluid/train/custom_trainer/feed/
accessor/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/
accessor/')
NEED_OUTPUT("baidu/third-party/mklml")
OUTPUT('paddle/fluid/train/custom_trainer/feed/conf', '$OUT')
OUTPUT('paddle/fluid/train/custom_trainer/feed/scripts', '$OUT')
...
...
paddle/fluid/train/custom_trainer/feed/accessor/accessor.h
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Abstract base for the accessor components of the feed trainer.
// An accessor is configured once from its YAML section together with the
// shared trainer context.
class Accessor {
public:
    Accessor() = default;
    virtual ~Accessor() = default;

    // Configures the accessor.
    //   config      - YAML node holding this accessor's settings.
    //   context_ptr - shared trainer context.
    // Callers in this code base treat a non-zero return value as failure.
    virtual int initialize(YAML::Node config,
            std::shared_ptr<TrainerContext> context_ptr) = 0;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Placeholder initialization: neither argument is read and 0 is always returned.
// NOTE(review): _model_root_path is not assigned here, so model_save_path()
// builds paths from an empty root unless it is set elsewhere - confirm.
int HourlyEpochAccessor::initialize(YAML::Node config,
        std::shared_ptr<TrainerContext> context_ptr) {
    return 0;
}
// Advances the accessor to the epoch that follows the current one.
void HourlyEpochAccessor::next_epoch() {
    const int upcoming_epoch_id = next_epoch_id(_current_epoch_id);
    _current_epoch_id = upcoming_epoch_id;
}
// Returns the label used for an epoch in logs: the decimal form of epoch_id.
std::string HourlyEpochAccessor::text(int epoch_id) {
    std::string label = std::to_string(epoch_id);
    return label;
}
// Reports whether the samples of epoch_id are available.
// Currently a stub: every epoch is reported as ready.
bool HourlyEpochAccessor::data_ready(int epoch_id) {
    const bool ready = true;
    return ready;
}
// Computes the id of the epoch following epoch_id.
// Epoch ids are second-resolution timestamps:
//   - a positive id is advanced by one hour (3600 s);
//   - a non-positive id means "no epoch yet" and yields the current wall
//     clock time rounded down to a multiple of 24 * 3600 seconds.
// NOTE(review): the timestamp is narrowed to int on return - confirm this is
// acceptable for the expected time range.
int HourlyEpochAccessor::next_epoch_id(int epoch_id) {
    if (epoch_id > 0) {
        return epoch_id + 3600;
    }
    const int seconds_per_day = 24 * 3600;
    struct timeval now;
    gettimeofday(&now, NULL);
    return now.tv_sec / seconds_per_day * seconds_per_day;
}
// True when epoch_id falls into hour 23 of its 24-hour cycle,
// i.e. (epoch_id / 3600) % 24 == 23.
bool HourlyEpochAccessor::is_last_epoch(int epoch_id) {
    const int hour_index = (epoch_id / 3600) % 24;
    return hour_index == 23;
}
// Decides whether a model must be saved for epoch_id using save_way.
//   - non-positive epoch ids never trigger a save;
//   - inference delta: every epoch;
//   - inference base: only for the last epoch (see is_last_epoch);
//   - train checkpoint: when (epoch_id / 3600) is a multiple of 8.
// Any other save_way value yields false.
bool HourlyEpochAccessor::need_save_model(int epoch_id, ModelSaveWay save_way) {
    if (epoch_id <= 0) {
        return false;
    }
    switch (save_way) {
    case ModelSaveWay::ModelSaveInferenceDelta:
        return true;
    case ModelSaveWay::ModelSaveInferenceBase:
        return is_last_epoch(epoch_id);
    case ModelSaveWay::ModelSaveTrainCheckpoint:
        return ((epoch_id / 3600) % 8) == 0;
    }
    return false;
}
// Builds the output path for a model of epoch_id saved with save_way.
// Only the delta path embeds the epoch id; base and checkpoint paths are fixed
// below _model_root_path. An unknown save_way yields an empty string.
std::string HourlyEpochAccessor::model_save_path(int epoch_id, ModelSaveWay save_way) {
    switch (save_way) {
    case ModelSaveWay::ModelSaveInferenceDelta:
        return _model_root_path + "/xbox/delta-" + std::to_string(epoch_id);
    case ModelSaveWay::ModelSaveInferenceBase:
        return _model_root_path + "/xbox/base";
    case ModelSaveWay::ModelSaveTrainCheckpoint:
        return _model_root_path + "/xbox/checkpoint";
    }
    return "";
}
REGISTER_CLASS
(
EpochAccessor
,
HourlyEpochAccessor
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Interface describing how training epochs are enumerated and when models are
// saved. Concrete implementations are created by name through the registerer.
class EpochAccessor : public Accessor {
public:
    EpochAccessor() {}
    virtual ~EpochAccessor() {}

    // Configures the accessor; a non-zero return is treated as failure by callers.
    virtual int initialize(YAML::Node config,
            std::shared_ptr<TrainerContext> context_ptr) = 0;

    // Id of the epoch the accessor currently points at.
    virtual int current_epoch_id() {
        return _current_epoch_id;
    }

    // Moves the accessor to the next epoch.
    virtual void next_epoch() = 0;

    // Human readable label for epoch_id (used in log messages).
    virtual std::string text(int epoch_id) = 0;

    // Whether the data of epoch_id is available for training.
    virtual bool data_ready(int epoch_id) = 0;

    // Id of the epoch that follows epoch_id.
    virtual int next_epoch_id(int epoch_id) = 0;

    // Whether epoch_id is the last epoch of its cycle.
    virtual bool is_last_epoch(int epoch_id) = 0;

    // Whether a model should be saved for epoch_id with the given save_way.
    virtual bool need_save_model(int epoch_id, ModelSaveWay save_way) = 0;

    // Destination path for a model of epoch_id saved with save_way.
    virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way) = 0;

protected:
    // Initialized to 0 so that current_epoch_id()/next_epoch() never read an
    // indeterminate value before the first epoch has been selected.
    int _current_epoch_id = 0;
};
REGISTER_REGISTERER
(
EpochAccessor
);
class
HourlyEpochAccessor
:
public
EpochAccessor
{
public:
HourlyEpochAccessor
()
{}
virtual
~
HourlyEpochAccessor
()
{}
virtual
int
initialize
(
YAML
::
Node
config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
void
next_epoch
();
virtual
std
::
string
text
(
int
epoch_id
);
virtual
bool
data_ready
(
int
epoch_id
);
virtual
int
next_epoch_id
(
int
epoch_id
);
virtual
bool
is_last_epoch
(
int
epoch_id
);
virtual
bool
need_save_model
(
int
epoch_id
,
ModelSaveWay
save_way
);
virtual
std
::
string
model_save_path
(
int
epoch_id
,
ModelSaveWay
save_way
);
private:
std
::
string
_model_root_path
;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/accessor/me
0 → 100644
浏览文件 @
c1c5c20d
此差异已折叠。
点击以展开。
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
0 → 100644
浏览文件 @
c1c5c20d
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Configuration initialization.
// Currently a stub: `config` is not inspected and 0 is returned unconditionally.
int MPIRuntimeEnvironment::initialize(YAML::Node config) {
    return 0;
}
// Environment initialization; meant to be called after all dependent modules
// have been initialized. Currently a stub that always returns 0.
int MPIRuntimeEnvironment::wireup() {
    return 0;
}
// rank_idx of the current environment.
// Currently a stub that always reports rank 0.
uint32_t MPIRuntimeEnvironment::rank_idx() {
    return 0;
}
// Barrier across the environment. Currently a no-op.
void MPIRuntimeEnvironment::barrier_all() {
}
// Emits log_str through VLOG(2).
// MASTER_LOG messages are printed only when is_master_node() is true; every
// other log type is printed unconditionally. `level` is currently unused.
void MPIRuntimeEnvironment::print_log(EnvironmentLogType type,
        EnvironmentLogLevel level, const std::string& log_str) {
    if (type != EnvironmentLogType::MASTER_LOG || is_master_node()) {
        VLOG(2) << log_str;
    }
}
REGISTER_CLASS
(
RuntimeEnvironment
,
MPIRuntimeEnvironment
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
浏览文件 @
c1c5c20d
...
...
@@ -5,28 +5,47 @@
*如:MPI环境下,写接口只允许单线程调用,那么默认对所有Env保证此调用限制
*/
#pragma once
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
enum
class
EnvironmentLogLevel
{
FATAL
=
0
,
ERROR
=
1
,
NOTICE
=
2
,
DEBUG
=
3
};
enum
class
EnvironmentLogType
{
MASTER_LOG
=
0
,
//仅master节点对外输出
ALL_LOG
=
1
//所有节点都会对外输出
};
class
RuntimeEnvironment
{
public:
RuntimeEnvironment
()
{}
virtual
~
RuntimeEnvironment
()
{}
//配置初始化
virtual
int
initialize
(
YAML
::
Node
&
config
)
=
0
;
virtual
int
initialize
(
YAML
::
Node
config
)
=
0
;
//环境初始化,会在所有依赖模块initialize后调用
virtual
int
wireup
()
=
0
;
//多线程可调用接口 Start
//当前环境rank_idx
virtual
uint32_t
rank_idx
()
=
0
;
//环境内主节点
virtual
bool
is_master_node
()
{
return
rank_idx
()
==
0
;
}
//环境定制化log
template
<
class
...
ARGS
>
void
log
(
int
log_type
,
const
char
*
fmt
,
ARGS
&&
...
args
)
{
print_log
(
log_type
,
paddle
::
string
::
format_string
(
fmt
,
args
...));
void
log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
char
*
fmt
,
ARGS
&&
...
args
)
{
print_log
(
type
,
level
,
paddle
::
string
::
format_string
(
fmt
,
args
...));
}
//多线程可调用接口 End
...
...
@@ -36,19 +55,24 @@ public:
virtual
void
barrier_all
()
=
0
;
//接口只允许在主线程调用 End
protected:
virtual
void
print_log
(
int
log_type
,
const
std
::
string
&
log_str
)
=
0
;
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
=
0
;
};
REGISTER_REGISTERER
(
RuntimeEnvironment
);
class
MPIRuntimeEnvironment
:
public
RuntimeEnvironment
{
public:
MPIRuntimeEnvironment
()
{}
virtual
~
MPIRuntimeEnvironment
()
{}
//配置初始化
virtual
int
initialize
(
YAML
::
Node
&
config
)
=
0
;
virtual
int
initialize
(
YAML
::
Node
config
)
;
//环境初始化,会在所有依赖模块initialize后调用
virtual
int
wireup
()
=
0
;
virtual
int
wireup
();
//当前环境rank_idx
virtual
uint32_t
rank_idx
()
=
0
;
virtual
uint32_t
rank_idx
();
virtual
void
barrier_all
();
protected:
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
);
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.h
浏览文件 @
c1c5c20d
...
...
@@ -8,13 +8,13 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
class
Execut
e
{
class
Execut
or
{
public:
Execut
e
()
{}
virtual
~
Execut
e
()
{}
Execut
or
()
{}
virtual
~
Execut
or
()
{}
//初始化,包括进行训练网络&配置加载工作
virtual
int
initialize
(
YAML
::
Node
&
exe_config
,
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
=
0
;
//scope 可用于填充&取 var
...
...
@@ -24,7 +24,7 @@ public:
//直接取var
template
<
class
T
>
T
*
var
(
const
std
::
string
&
name
)
{
return
_scope
.
Var
(
name
)
.
Get
<
T
>
();
return
_scope
.
Var
(
name
)
->
Get
<
T
>
();
}
template
<
class
T
>
T
*
mutable_var
(
const
std
::
string
&
name
)
{
...
...
@@ -34,20 +34,23 @@ public:
//执行n轮训练,每轮回调(epoch_id, _scope)
virtual
int
run
(
uint32_t
epoch_num
,
std
::
function
<
void
(
uint32_t
,
::
paddle
::
framework
::
Scope
*
)
>
)
=
0
;
virtual
bool
is_dump_all_model
()
{
return
false
;
}
protected:
::
paddle
::
framework
::
Scope
_scope
;
};
REGISTER_REGISTERER
(
Execut
e
);
REGISTER_REGISTERER
(
Execut
or
);
class
SimpleExecut
e
:
public
Execute
{
class
SimpleExecut
or
:
public
Executor
{
public:
SimpleExecut
e
()
{}
virtual
~
SimpleExecut
e
()
{}
virtual
int
initialize
(
YAML
::
Node
&
exe_config
,
SimpleExecut
or
()
{}
virtual
~
SimpleExecut
or
()
{}
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
int
run
(
uint32_t
epoch_num
,
std
::
function
<
void
(
uint32_t
,
::
paddle
::
framework
::
Scope
*
)
>
)
=
0
;
protected:
::
paddle
::
framework
::
Executor
_execute
;
std
::
shared_ptr
<::
paddle
::
framework
::
Executor
>
_executor
;
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/main.cc
浏览文件 @
c1c5c20d
...
...
@@ -19,12 +19,12 @@ int main(int argc, char* argv[]) {
//load trainer config
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
VLOG
(
3
)
<<
"yaml node size"
<<
trainer_context_ptr
->
trainer_config
.
size
();
std
::
vector
<
std
::
string
>
process_name_list
=
{
"InitEnvProcess"
"InitEnvProcess"
,
"LearnerProcess"
};
InitEnvProcess
init_process
;
init_process
.
run
();
for
(
const
auto
&
process_name
:
process_name_list
)
{
Process
*
process
=
CREATE_CLASS
(
Process
,
process_name
);
...
...
paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.h
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include <string>
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
//TODO complete AucMonitor
// Monitor intended to report AUC. Only initialize() is defined in this header;
// the remaining members are declarations.
class AucMonitor : public Monitor {
public:
    AucMonitor() {}
    virtual ~AucMonitor() {}

    // Runs the common Monitor initialization and reports its failure to the
    // caller instead of silently discarding the return value.
    // Returns 0 on success, -1 when the base initialization fails.
    virtual int initialize(const YAML::Node& config,
            std::shared_ptr<TrainerContext> context_ptr) {
        if (Monitor::initialize(config, context_ptr) != 0) {
            return -1;
        }
        // Extra configuration; for AUC this is mainly target && label information
        return 0;
    }

    // Adds one record; the monitor fetches what it needs from the Executor itself
    virtual void add_data(int epoch_id, const Executor* executor);

    // Whether result computation should start
    virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor);

    // Computes the current result
    virtual void compute_result();

    // Produces formatted statistics based on the current result
    virtual std::string format_result();

    virtual void reset();
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include <string>
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Base class of training monitors. A monitor collects records during
// training and produces a formatted result on demand.
class Monitor {
public:
    Monitor() {}
    virtual ~Monitor() {}

    // Reads the monitor name from config["name"]. Returns 0 on success.
    // The original body referenced an undeclared identifier `conf`; the
    // parameter is called `config`.
    virtual int initialize(const YAML::Node& config,
            std::shared_ptr<TrainerContext> context_ptr) {
        _name = config["name"].as<std::string>();
        return 0;
    }

    // Adds one record; the monitor fetches what it needs from the Executor itself
    virtual void add_data(int epoch_id, const Executor* executor) = 0;

    // Whether results should be computed for the given epoch_id
    virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor) = 0;

    // Computes the current result
    virtual void compute_result() = 0;

    // Produces formatted statistics based on the current result
    virtual std::string format_result() = 0;

    virtual void reset() = 0;

    // Name read from the configuration in initialize().
    const std::string& get_name() const {
        return _name;
    }

protected:
    std::string _name;
};
REGISTER_REGISTERER
(
Monitor
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
浏览文件 @
c1c5c20d
...
...
@@ -5,6 +5,7 @@
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
namespace
paddle
{
...
...
@@ -12,12 +13,39 @@ namespace custom_trainer {
namespace
feed
{
// Prepares the shared trainer context:
//   1. stores the context (Process::initialize) and initializes devices;
//   2. creates and initializes the RuntimeEnvironment named by
//      config["environment"]["environment_class"];
//   3. creates and initializes the EpochAccessor named by
//      config["epoch"]["epoch_class"].
// Returns 0 on success and -1 as soon as any step fails.
int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    if (Process::initialize(context_ptr) != 0) {
        return -1;
    }
    paddle::framework::InitDevices(false);
    context_ptr->cpu_place = paddle::platform::CPUPlace();

    YAML::Node config;
    config.reset(_context_ptr->trainer_config);
    VLOG(3) << "yaml node size : " << config.size();

    //environment
    std::string env_class = config["environment"]["environment_class"].as<std::string>();
    auto* environment = CREATE_CLASS(RuntimeEnvironment, env_class);
    // NOTE(review): assumes CREATE_CLASS yields a null pointer for an unknown
    // class name - confirm against registerer.h.
    if (environment == NULL) {
        return -1;
    }
    // Take ownership right away so the object is released if initialize fails.
    std::shared_ptr<RuntimeEnvironment> environment_owner(environment);
    if (environment_owner->initialize(config["environment"]) != 0) {
        return -1;
    }
    context_ptr->environment = environment_owner;

    //epoch
    std::string epoch_class = config["epoch"]["epoch_class"].as<std::string>();
    auto* epoch = CREATE_CLASS(EpochAccessor, epoch_class);
    if (epoch == NULL) {
        return -1;
    }
    std::shared_ptr<EpochAccessor> epoch_owner(epoch);
    if (epoch_owner->initialize(config["epoch"], context_ptr) != 0) {
        return -1;
    }
    context_ptr->epoch_accessor = epoch_owner;

    VLOG(3) << "Env initialize success";
    return 0;
}
// Placeholder for the environment start-up work; nothing is executed yet.
// Planned steps:
//   step 1. psserver init
//   step 2. psserver load
int InitEnvProcess::run() {
    return 0;
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/process/init_env_process.h
浏览文件 @
c1c5c20d
...
...
@@ -14,6 +14,7 @@ public:
InitEnvProcess
()
{}
virtual
~
InitEnvProcess
()
{}
virtual
int
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
int
run
();
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
0 → 100644
浏览文件 @
c1c5c20d
/*
*Author: xiexionghang
*Train样本
*/
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Reads the training configuration and builds, for every training thread, one
// Executor per entry of config["executor"].
// Returns 0 on success and -1 when the base initialization fails, the thread
// count is not positive, or any executor cannot be created or initialized.
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
    _train_thread_num = config["train_thread_num"].as<int>();
    if (_train_thread_num <= 0) {
        // A non-positive count would turn into a huge unsigned size in resize().
        return -1;
    }
    _threads_executor.resize(_train_thread_num);

    if (config["executor"]) {
        _executor_num = config["executor"].size();
        omp_set_num_threads(_train_thread_num);
        // NOTE(review): `config` is read concurrently from several OpenMP
        // threads below - confirm this is safe for the YAML library in use.
        #pragma omp parallel for
        for (int i = 0; i < _train_thread_num; ++i) {
            _threads_executor[i].resize(_executor_num);
            for (int e = 0; e < _executor_num; ++e) {
                auto e_class = config["executor"][e]["class"].as<std::string>();
                auto* e_ptr = CREATE_CLASS(Executor, e_class);
                _threads_executor[i][e].reset(e_ptr);
                // NOTE(review): assumes CREATE_CLASS yields a null pointer for
                // an unknown class name - confirm against registerer.h.
                if (e_ptr == NULL ||
                        e_ptr->initialize(config["executor"][e], context_ptr) != 0) {
                    // `ret` is shared between the parallel iterations.
                    #pragma omp critical
                    {
                        ret = -1;
                    }
                }
            }
        }
    }
    // The original returned a literal 0 here and thereby hid every failure.
    return ret;
}
// Asynchronously saves the model of table_id for epoch_id using `way`.
// The actual save is not implemented yet, so the returned future is always
// ready and holds 0.
std::future<int> LearnerProcess::save_model(
        int epoch_id, int table_id, ModelSaveWay way) {
    std::promise<int> p;
    auto ret = p.get_future();
    if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {
        //TODO
        //context_ptr->pslib_client()->save();
    }
    // Fulfil the promise on every path: destroying an unfulfilled promise makes
    // future::get() throw std::future_error (broken_promise) in wait_save_model().
    p.set_value(0);
    return ret;
}
// Synchronously saves every table listed in the trainer context.
// Only the master node performs the save; other nodes return 0 immediately.
// Returns the bitwise OR of all individual save results (0 when all succeed
// or when there is no table).
int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node()) {
        return 0;
    }

    // A vector replaces the former variable-length array of futures, which is
    // not standard C++ and had zero length when no table was configured.
    auto& table_list = _context_ptr->params_table_list;
    std::vector<std::future<int>> rets;
    rets.reserve(table_list.size());
    for (auto& table : table_list) {
        rets.push_back(save_model(epoch_id, table.table_id(), way));
    }

    int all_ret = 0;
    for (auto& pending : rets) {
        // get() blocks until the result is available.
        all_ret |= pending.get();
    }
    return all_ret;
}
int
LearnerProcess
::
run
()
{
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
int
epoch_id
=
epoch_accessor
->
current_epoch_id
();
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Resume traine with epoch_id:%d label:%s"
,
epoch_id
,
_context_ptr
->
epoch_accessor
->
text
(
epoch_id
).
c_str
());
//判断是否先dump出base
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier_all
();
while
(
true
)
{
epoch_accessor
->
next_epoch
();
epoch_id
=
epoch_accessor
->
current_epoch_id
();
std
::
string
epoch_log_title
=
paddle
::
string
::
format_string
(
"train epoch_id:%d label:%s"
,
epoch_id
,
epoch_accessor
->
text
(
epoch_id
).
c_str
());
//Step1. 等待样本ready
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Start %s, wait data ready"
,
epoch_log_title
.
c_str
());
while
(
!
epoch_accessor
->
data_ready
(
epoch_id
))
{
sleep
(
30
);
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data not ready, wait 30s"
,
epoch_log_title
.
c_str
());
}
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data is ready, start traning"
,
epoch_log_title
.
c_str
());
environment
->
barrier_all
();
//Step2. 运行训练网络
bool
already_dump_inference_model
=
false
;
for
(
int
i
=
0
;
i
<
_executor_num
;
++
i
)
{
std
::
vector
<
std
::
shared_ptr
<
std
::
thread
>>
train_threads
(
_train_thread_num
);
for
(
int
thread_id
=
0
;
thread_id
<
_train_thread_num
;
++
thread_id
)
{
train_threads
[
i
].
reset
(
new
std
::
thread
([
this
](
int
exe_idx
,
int
thread_idx
)
{
auto
*
executor
=
_threads_executor
[
thread_idx
][
exe_idx
].
get
();
run_executor
(
executor
);
},
i
,
thread_id
));
}
for
(
int
i
=
0
;
i
<
_train_thread_num
;
++
i
)
{
train_threads
[
i
]
->
join
();
}
environment
->
barrier_all
();
if
(
_threads_executor
[
0
][
i
]
->
is_dump_all_model
())
{
already_dump_inference_model
=
true
;
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
environment
->
barrier_all
();
}
//Step3. Dump Model For Delta&&Checkpoint
if
(
!
already_dump_inference_model
)
{
already_dump_inference_model
=
true
;
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpoint
);
environment
->
barrier_all
();
//Step4. Output Monitor && RunStatus
//TODO
}
return
0
;
}
// Runs one training network on the calling thread.
// Not implemented yet: `executor` is ignored and 0 is returned.
int LearnerProcess::run_executor(Executor* executor) {
    //TODO
    const int status = 0;
    return status;
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
0 → 100644
浏览文件 @
c1c5c20d
/*
*Author: xiexionghang
*Train样本
*/
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
typedef
std
::
vector
<
std
::
shared_ptr
<
Executor
>>
MultiExecutor
;
class
LearnerProcess
:
public
Process
{
public:
LearnerProcess
()
{}
virtual
~
LearnerProcess
()
{}
virtual
int
run
();
virtual
int
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
protected:
//同步保存所有模型
virtual
int
wait_save_model
(
int
epoch_id
,
ModelSaveWay
way
);
//异步保存指定模型
virtual
std
::
future
<
int
>
save_model
(
int
epoch_id
,
int
table_id
,
ModelSaveWay
way
);
//执行指定训练网络
virtual
int
run_executor
(
Executor
*
executor
);
private:
int
_executor_num
=
0
;
//需要执行训练的网络个数
int
_train_thread_num
=
1
;
//并行训练线程数
std
::
vector
<
MultiExecutor
>
_threads_executor
;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/process/process.cc
浏览文件 @
c1c5c20d
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
REGISTER_CLASS
(
Process
,
InitEnvProcess
);
REGISTER_CLASS
(
Process
,
LearnerProcess
);
// Default implementation: does nothing and returns 0.
int Process::run() {
    const int status = 0;
    return status;
}
...
...
paddle/fluid/train/custom_trainer/feed/process/process.h
浏览文件 @
c1c5c20d
...
...
@@ -10,8 +10,13 @@ class Process {
public:
Process
()
{}
virtual
~
Process
()
{}
virtual
int
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
=
0
;
virtual
int
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
_context_ptr
=
context_ptr
.
get
();
return
0
;
}
virtual
int
run
();
protected:
TrainerContext
*
_context_ptr
=
NULL
;
};
REGISTER_REGISTERER
(
Process
);
...
...
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
c1c5c20d
...
...
@@ -12,11 +12,31 @@ namespace custom_trainer {
namespace
feed
{
class
Process
;
class
EpochAccessor
;
// Kinds of model dumps that can be requested from the trainer.
enum class ModelSaveWay {
    ModelSaveTrainCheckpoint = 0,  // checkpoint of the training state
    ModelSaveInferenceDelta = 1,   // inference model, delta dump
    ModelSaveInferenceBase = 2     // inference model, base dump
};
// Metadata of one parameter table.
class TableMeta {
public:
    TableMeta() {}
    ~TableMeta() {}

    // Id of the table. Callable on const objects as well.
    int table_id() const {
        return _id;
    }

private:
    // Initialized so that table_id() never returns an indeterminate value;
    // nothing visible in this class assigns it.
    int _id = 0;
};
// State shared by all processes of the trainer. Every member is public and is
// filled in by the processes during initialization.
class TrainerContext {
public:
    // Trainer configuration tree.
    YAML::Node trainer_config;
    // CPU place handed to the framework.
    paddle::platform::CPUPlace cpu_place;
    // Parameter tables that take part in model saving.
    std::vector<TableMeta> params_table_list;
    // Epoch enumeration / model-save policy.
    std::shared_ptr<EpochAccessor> epoch_accessor;
    // Runtime environment (rank, barrier, logging).
    std::shared_ptr<RuntimeEnvironment> environment;
    // Processes of this trainer.
    std::vector<std::shared_ptr<Process>> process_list;
};
...
...
publish_include.sh
浏览文件 @
c1c5c20d
#!bash
OUTPUT_PATH
=
../../../bc_out/baidu/feed-mlarch/paddle-trainer/output/include/
INCLUDE_DIR
=
paddle/fluid/train/custom_trainer/feed/
SUB_DIR_LIST
=(
common dataset
params_access
or process shuffler
)
SUB_DIR_LIST
=(
common dataset
accessor executor monit
or process shuffler
)
rm
-rf
${
OUTPUT_PATH
}
/
${
INCLUDE_DIR
}
/
*
cp
${
INCLUDE_DIR
}
/
*
.h
${
OUTPUT_PATH
}
/
${
INCLUDE_DIR
}
/
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录