Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
c1c5c20d
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c1c5c20d
编写于
8月 01, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
for runnable trainer
上级
29aec8e0
变更
19
展开全部
隐藏空白更改
内联
并排
Showing
19 changed file
with
15379 addition
and
24 deletion
+15379
-24
BCLOUD
BCLOUD
+3
-1
paddle/fluid/train/custom_trainer/feed/accessor/accessor.h
paddle/fluid/train/custom_trainer/feed/accessor/accessor.h
+19
-0
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+58
-0
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
...fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+50
-0
paddle/fluid/train/custom_trainer/feed/accessor/me
paddle/fluid/train/custom_trainer/feed/accessor/me
+14840
-0
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
...d/train/custom_trainer/feed/common/runtime_environment.cc
+34
-0
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
...id/train/custom_trainer/feed/common/runtime_environment.h
+31
-7
paddle/fluid/train/custom_trainer/feed/executor/executor.h
paddle/fluid/train/custom_trainer/feed/executor/executor.h
+14
-11
paddle/fluid/train/custom_trainer/feed/main.cc
paddle/fluid/train/custom_trainer/feed/main.cc
+3
-3
paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.h
paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.h
+38
-0
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
+46
-0
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
...uid/train/custom_trainer/feed/process/init_env_process.cc
+28
-0
paddle/fluid/train/custom_trainer/feed/process/init_env_process.h
...luid/train/custom_trainer/feed/process/init_env_process.h
+1
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+145
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
...fluid/train/custom_trainer/feed/process/learner_process.h
+40
-0
paddle/fluid/train/custom_trainer/feed/process/process.cc
paddle/fluid/train/custom_trainer/feed/process/process.cc
+2
-0
paddle/fluid/train/custom_trainer/feed/process/process.h
paddle/fluid/train/custom_trainer/feed/process/process.h
+6
-1
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+20
-0
publish_include.sh
publish_include.sh
+1
-1
未找到文件。
BCLOUD
浏览文件 @
c1c5c20d
...
...
@@ -70,10 +70,12 @@ SharedLibrary('paddle_fluid_avx_mklml', Sources(paddle_fluid_avx_mklml_src, Cpp
#feed
HEADERS('paddle/fluid/train/custom_trainer/feed/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/')
HEADERS('paddle/fluid/train/custom_trainer/feed/common/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/common/')
HEADERS('paddle/fluid/train/custom_trainer/feed/executor/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/executor/')
HEADERS('paddle/fluid/train/custom_trainer/feed/monitor/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/monitor/')
HEADERS('paddle/fluid/train/custom_trainer/feed/dataset/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/dataset/')
HEADERS('paddle/fluid/train/custom_trainer/feed/process/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/process/')
HEADERS('paddle/fluid/train/custom_trainer/feed/shuffler/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/shuffler/')
HEADERS('paddle/fluid/train/custom_trainer/feed/
params_accessor/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/params_
accessor/')
HEADERS('paddle/fluid/train/custom_trainer/feed/
accessor/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/
accessor/')
NEED_OUTPUT("baidu/third-party/mklml")
OUTPUT('paddle/fluid/train/custom_trainer/feed/conf', '$OUT')
OUTPUT('paddle/fluid/train/custom_trainer/feed/scripts', '$OUT')
...
...
paddle/fluid/train/custom_trainer/feed/accessor/accessor.h
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Abstract base for accessors. A concrete accessor is configured from a YAML
// node together with the shared trainer context.
class Accessor {
public:
    Accessor() {}
    virtual ~Accessor() {}

    // Configures the accessor from `config`. Pure virtual: every subclass must
    // provide it. Callers in this code base treat a non-zero return as failure.
    virtual int initialize(YAML::Node config,
                           std::shared_ptr<TrainerContext> context_ptr) = 0;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Configuration hook for the hourly epoch accessor. Currently a stub: neither
// `config` nor `context_ptr` is read, and 0 is always returned.
int HourlyEpochAccessor::initialize(YAML::Node config,
                                    std::shared_ptr<TrainerContext> context_ptr) {
    return 0;
}
// Advances the stored epoch id to the successor computed by next_epoch_id().
void HourlyEpochAccessor::next_epoch() {
    const int following_epoch = next_epoch_id(_current_epoch_id);
    _current_epoch_id = following_epoch;
}
// Returns the textual label of an epoch: the decimal rendering of its id.
std::string HourlyEpochAccessor::text(int epoch_id) {
    std::string label = std::to_string(epoch_id);
    return label;
}
// Reports whether the training data of `epoch_id` is available.
// Currently a stub that always answers true, whatever the epoch.
bool HourlyEpochAccessor::data_ready(int epoch_id) {
    return true;
}
// Computes the id of the epoch that follows `epoch_id`.
//  - epoch_id <= 0: returns the current wall-clock time (seconds, from
//    gettimeofday) rounded down to a multiple of 24 * 3600.
//  - otherwise:     returns epoch_id + 3600.
int HourlyEpochAccessor::next_epoch_id(int epoch_id) {
    if (epoch_id > 0) {
        return epoch_id + 3600;
    }
    struct timeval now;
    gettimeofday(&now, NULL);
    // Same value as tv_sec / day * day: drop the remainder within the day.
    return now.tv_sec - now.tv_sec % (24 * 3600);
}
// True when `epoch_id`, divided into 3600-unit slots, falls into slot 23 of a
// 24-slot cycle.
bool HourlyEpochAccessor::is_last_epoch(int epoch_id) {
    const int slot_in_cycle = (epoch_id / 3600) % 24;
    return slot_in_cycle == 23;
}
// Decides whether a model of kind `save_way` is saved for `epoch_id`.
//  - epoch_id <= 0:            never
//  - ModelSaveInferenceDelta:  always
//  - ModelSaveInferenceBase:   only when is_last_epoch(epoch_id) holds
//  - ModelSaveTrainCheckpoint: when (epoch_id / 3600) is a multiple of 8
//  - any other value:          never
bool HourlyEpochAccessor::need_save_model(int epoch_id, ModelSaveWay save_way) {
    if (epoch_id <= 0) {
        return false;
    }
    switch (save_way) {
        case ModelSaveWay::ModelSaveInferenceDelta:
            return true;
        case ModelSaveWay::ModelSaveInferenceBase:
            return is_last_epoch(epoch_id);
        case ModelSaveWay::ModelSaveTrainCheckpoint:
            return ((epoch_id / 3600) % 8) == 0;
        default:
            return false;
    }
}
// Builds the output path for a model of kind `save_way`, rooted at
// _model_root_path. Only the delta path embeds the epoch id; an unrecognised
// save way yields an empty string.
std::string HourlyEpochAccessor::model_save_path(int epoch_id, ModelSaveWay save_way) {
    switch (save_way) {
        case ModelSaveWay::ModelSaveInferenceDelta:
            return _model_root_path + "/xbox/delta-" + std::to_string(epoch_id);
        case ModelSaveWay::ModelSaveInferenceBase:
            return _model_root_path + "/xbox/base";
        case ModelSaveWay::ModelSaveTrainCheckpoint:
            return _model_root_path + "/xbox/checkpoint";
        default:
            return "";
    }
}
REGISTER_CLASS
(
EpochAccessor
,
HourlyEpochAccessor
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Abstract description of how training is divided into epochs: which epoch is
// current, what follows it, whether its data is ready, and when/where models
// are saved.
class EpochAccessor : public Accessor {
public:
    EpochAccessor() {}
    virtual ~EpochAccessor() {}

    // Configures the accessor; callers treat a non-zero return as failure.
    virtual int initialize(YAML::Node config,
                           std::shared_ptr<TrainerContext> context_ptr) = 0;

    // Id of the epoch currently being processed.
    virtual int current_epoch_id() {
        return _current_epoch_id;
    }

    // Moves the accessor on to the next epoch.
    virtual void next_epoch() = 0;

    // Human-readable label of `epoch_id`.
    virtual std::string text(int epoch_id) = 0;

    // Whether the data of `epoch_id` is ready.
    virtual bool data_ready(int epoch_id) = 0;

    // Id of the epoch that follows `epoch_id`.
    virtual int next_epoch_id(int epoch_id) = 0;

    // Whether `epoch_id` is the last epoch of its cycle.
    virtual bool is_last_epoch(int epoch_id) = 0;

    // Whether a model of kind `save_way` should be saved for `epoch_id`.
    virtual bool need_save_model(int epoch_id, ModelSaveWay save_way) = 0;

    // Destination path for a model of kind `save_way` saved at `epoch_id`.
    virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way) = 0;

protected:
    // Initialised to 0 so that current_epoch_id() / next_epoch() never read an
    // indeterminate value before the first epoch is set. HourlyEpochAccessor
    // treats ids <= 0 as "no epoch yet".
    int _current_epoch_id = 0;
};
REGISTER_REGISTERER
(
EpochAccessor
);
// EpochAccessor implementation whose epoch ids advance in steps of 3600
// (see next_epoch_id in epoch_accessor.cc).
class HourlyEpochAccessor : public EpochAccessor {
public:
    HourlyEpochAccessor() {}
    virtual ~HourlyEpochAccessor() {}

    virtual int initialize(YAML::Node config,
                           std::shared_ptr<TrainerContext> context_ptr);

    // Epoch progression.
    virtual void next_epoch();
    virtual std::string text(int epoch_id);
    virtual bool data_ready(int epoch_id);
    virtual int next_epoch_id(int epoch_id);
    virtual bool is_last_epoch(int epoch_id);

    // Model saving policy and destinations.
    virtual bool need_save_model(int epoch_id, ModelSaveWay save_way);
    virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way);

private:
    // Prefix of every path returned by model_save_path().
    std::string _model_root_path;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/accessor/me
0 → 100644
浏览文件 @
c1c5c20d
此差异已折叠。
点击以展开。
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
0 → 100644
浏览文件 @
c1c5c20d
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Configuration initialisation. Currently a stub: `config` is not read and
// 0 is always returned.
int MPIRuntimeEnvironment::initialize(YAML::Node config) {
    return 0;
}
// Environment initialisation; called after all dependent modules have been
// initialised. Currently a stub that always returns 0.
int MPIRuntimeEnvironment::wireup() {
    return 0;
}
// rank_idx of the current environment. Currently a stub that always reports 0.
uint32_t MPIRuntimeEnvironment::rank_idx() {
    return 0;
}
// Barrier hook. Currently a no-op: it returns immediately.
void MPIRuntimeEnvironment::barrier_all() {
}
// Emits `log_str` through VLOG(2). Messages of type MASTER_LOG are dropped
// on nodes for which is_master_node() is false. `level` is not consulted.
void MPIRuntimeEnvironment::print_log(EnvironmentLogType type,
                                      EnvironmentLogLevel level,
                                      const std::string& log_str) {
    const bool master_only = (type == EnvironmentLogType::MASTER_LOG);
    if (master_only && !is_master_node()) {
        return;
    }
    VLOG(2) << log_str;
}
REGISTER_CLASS
(
RuntimeEnvironment
,
MPIRuntimeEnvironment
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
浏览文件 @
c1c5c20d
...
...
@@ -5,28 +5,47 @@
*如:MPI环境下,写接口只允许单线程调用,那么默认对所有Env保证此调用限制
*/
#pragma once
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Log levels accepted by the environment logging interface.
enum class EnvironmentLogLevel {
    FATAL  = 0,
    ERROR  = 1,
    NOTICE = 2,
    DEBUG  = 3
};
// Selects which nodes emit a given log message.
enum class EnvironmentLogType {
    // Only the master node outputs the message.
    MASTER_LOG = 0,
    // Every node outputs the message.
    ALL_LOG = 1
};
class
RuntimeEnvironment
{
public:
RuntimeEnvironment
()
{}
virtual
~
RuntimeEnvironment
()
{}
//配置初始化
virtual
int
initialize
(
YAML
::
Node
&
config
)
=
0
;
virtual
int
initialize
(
YAML
::
Node
config
)
=
0
;
//环境初始化,会在所有依赖模块initialize后调用
virtual
int
wireup
()
=
0
;
//多线程可调用接口 Start
//当前环境rank_idx
virtual
uint32_t
rank_idx
()
=
0
;
//环境内主节点
virtual
bool
is_master_node
()
{
return
rank_idx
()
==
0
;
}
//环境定制化log
template
<
class
...
ARGS
>
void
log
(
int
log_type
,
const
char
*
fmt
,
ARGS
&&
...
args
)
{
print_log
(
log_type
,
paddle
::
string
::
format_string
(
fmt
,
args
...));
void
log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
char
*
fmt
,
ARGS
&&
...
args
)
{
print_log
(
type
,
level
,
paddle
::
string
::
format_string
(
fmt
,
args
...));
}
//多线程可调用接口 End
...
...
@@ -36,19 +55,24 @@ public:
virtual
void
barrier_all
()
=
0
;
//接口只允许在主线程调用 End
protected:
virtual
void
print_log
(
int
log_type
,
const
std
::
string
&
log_str
)
=
0
;
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
=
0
;
};
REGISTER_REGISTERER
(
RuntimeEnvironment
);
class
MPIRuntimeEnvironment
:
public
RuntimeEnvironment
{
public:
MPIRuntimeEnvironment
()
{}
virtual
~
MPIRuntimeEnvironment
()
{}
//配置初始化
virtual
int
initialize
(
YAML
::
Node
&
config
)
=
0
;
virtual
int
initialize
(
YAML
::
Node
config
)
;
//环境初始化,会在所有依赖模块initialize后调用
virtual
int
wireup
()
=
0
;
virtual
int
wireup
();
//当前环境rank_idx
virtual
uint32_t
rank_idx
()
=
0
;
virtual
uint32_t
rank_idx
();
virtual
void
barrier_all
();
protected:
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
);
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.h
浏览文件 @
c1c5c20d
...
...
@@ -8,13 +8,13 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
class
Execut
e
{
class
Execut
or
{
public:
Execut
e
()
{}
virtual
~
Execut
e
()
{}
Execut
or
()
{}
virtual
~
Execut
or
()
{}
//初始化,包括进行训练网络&配置加载工作
virtual
int
initialize
(
YAML
::
Node
&
exe_config
,
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
=
0
;
//scope 可用于填充&取 var
...
...
@@ -24,7 +24,7 @@ public:
//直接取var
template
<
class
T
>
T
*
var
(
const
std
::
string
&
name
)
{
return
_scope
.
Var
(
name
)
.
Get
<
T
>
();
return
_scope
.
Var
(
name
)
->
Get
<
T
>
();
}
template
<
class
T
>
T
*
mutable_var
(
const
std
::
string
&
name
)
{
...
...
@@ -34,20 +34,23 @@ public:
//执行n轮训练,每轮回调(epoch_id, _scope)
virtual
int
run
(
uint32_t
epoch_num
,
std
::
function
<
void
(
uint32_t
,
::
paddle
::
framework
::
Scope
*
)
>
)
=
0
;
virtual
bool
is_dump_all_model
()
{
return
false
;
}
protected:
::
paddle
::
framework
::
Scope
_scope
;
};
REGISTER_REGISTERER
(
Execut
e
);
REGISTER_REGISTERER
(
Execut
or
);
class
SimpleExecut
e
:
public
Execute
{
class
SimpleExecut
or
:
public
Executor
{
public:
SimpleExecut
e
()
{}
virtual
~
SimpleExecut
e
()
{}
virtual
int
initialize
(
YAML
::
Node
&
exe_config
,
SimpleExecut
or
()
{}
virtual
~
SimpleExecut
or
()
{}
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
int
run
(
uint32_t
epoch_num
,
std
::
function
<
void
(
uint32_t
,
::
paddle
::
framework
::
Scope
*
)
>
)
=
0
;
protected:
::
paddle
::
framework
::
Executor
_execute
;
std
::
shared_ptr
<::
paddle
::
framework
::
Executor
>
_executor
;
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/main.cc
浏览文件 @
c1c5c20d
...
...
@@ -19,12 +19,12 @@ int main(int argc, char* argv[]) {
//load trainer config
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
VLOG
(
3
)
<<
"yaml node size"
<<
trainer_context_ptr
->
trainer_config
.
size
();
std
::
vector
<
std
::
string
>
process_name_list
=
{
"InitEnvProcess"
"InitEnvProcess"
,
"LearnerProcess"
};
InitEnvProcess
init_process
;
init_process
.
run
();
for
(
const
auto
&
process_name
:
process_name_list
)
{
Process
*
process
=
CREATE_CLASS
(
Process
,
process_name
);
...
...
paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.h
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include <string>
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// TODO: complete AucMonitor
// Monitor specialisation intended for AUC statistics. Only initialize() is
// defined here; the remaining overrides are declarations.
class AucMonitor : public Monitor {
public:
    AucMonitor() {}
    virtual ~AucMonitor() {}

    // Runs the base-class initialisation and propagates its failure code
    // instead of discarding it. Returns 0 on success.
    virtual int initialize(const YAML::Node& config,
                           std::shared_ptr<TrainerContext> context_ptr) {
        int base_ret = Monitor::initialize(config, context_ptr);
        if (base_ret != 0) {
            return base_ret;
        }
        // Extra configuration goes here; for AUC this is mainly the target and
        // label information.
        return 0;
    }

    // Adds one record; the monitor fetches what it needs from the Executor.
    virtual void add_data(int epoch_id, const Executor* executor);

    // Whether result computation should start.
    virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor);

    // Computes the current result.
    virtual void compute_result();

    // Formats the statistics of the current result as text.
    virtual std::string format_result();

    virtual void reset();
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
0 → 100644
浏览文件 @
c1c5c20d
#pragma once
#include <string>
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Base class for training monitors: collects records per epoch, decides when
// to compute a result, and formats it.
class Monitor {
public:
    Monitor() {}
    virtual ~Monitor() {}

    // Reads the monitor's name from the "name" entry of `config`.
    // Fix: the body referred to an undeclared `conf`; the parameter is
    // `config`, so the original did not compile.
    // Returns 0.
    virtual int initialize(const YAML::Node& config,
                           std::shared_ptr<TrainerContext> context_ptr) {
        _name = config["name"].as<std::string>();
        return 0;
    }

    // Adds one record; the monitor fetches what it needs from the Executor.
    virtual void add_data(int epoch_id, const Executor* executor) = 0;

    // Whether a result should be computed for the given epoch_id.
    virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor) = 0;

    // Computes the current result.
    virtual void compute_result() = 0;

    // Formats the statistics of the current result as text.
    virtual std::string format_result() = 0;

    virtual void reset() = 0;

    // Name read from the configuration by initialize().
    const std::string& get_name() {
        return _name;
    }

protected:
    std::string _name;
};
REGISTER_REGISTERER
(
Monitor
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
浏览文件 @
c1c5c20d
...
...
@@ -5,6 +5,7 @@
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
namespace
paddle
{
...
...
@@ -12,12 +13,39 @@ namespace custom_trainer {
namespace
feed
{
// Builds the runtime environment and the epoch accessor named in the trainer
// configuration and stores them in the trainer context.
// Returns 0 on success, -1 if either object cannot be created or initialised.
int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    Process::initialize(context_ptr);
    paddle::framework::InitDevices(false);
    context_ptr->cpu_place = paddle::platform::CPUPlace();

    YAML::Node config;
    config.reset(_context_ptr->trainer_config);
    VLOG(3) << "yaml node size : " << config.size();

    // environment
    std::string env_class = config["environment"]["environment_class"].as<std::string>();
    // Take ownership immediately so the object is released on every failure
    // path (the original leaked the raw pointer when initialize() failed).
    std::shared_ptr<RuntimeEnvironment> environment(
        CREATE_CLASS(RuntimeEnvironment, env_class));
    // NOTE(review): presumably CREATE_CLASS yields null for an unregistered
    // class name - confirm against registerer.h. The check is harmless either way.
    if (!environment) {
        return -1;
    }
    if (environment->initialize(config["environment"]) != 0) {
        return -1;
    }
    context_ptr->environment = environment;

    // epoch
    std::string epoch_class = config["epoch"]["epoch_class"].as<std::string>();
    std::shared_ptr<EpochAccessor> epoch(CREATE_CLASS(EpochAccessor, epoch_class));
    if (!epoch) {
        return -1;
    }
    if (epoch->initialize(config["epoch"], context_ptr) != 0) {
        return -1;
    }
    context_ptr->epoch_accessor = epoch;

    VLOG(3) << "Env initialize success";
    return 0;
}
// Placeholder for the parameter-server steps; nothing is executed yet.
int InitEnvProcess::run() {
    // step 1. psserver init
    // step 2. psserver load
    return 0;
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/process/init_env_process.h
浏览文件 @
c1c5c20d
...
...
@@ -14,6 +14,7 @@ public:
InitEnvProcess
()
{}
virtual
~
InitEnvProcess
()
{}
virtual
int
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
int
run
();
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
0 → 100644
浏览文件 @
c1c5c20d
/*
*Author: xiexionghang
*Train样本
*/
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Reads the training thread count and, when an "executor" section exists,
// creates and initialises one executor per (thread, executor-config) pair.
// Returns 0 on success and -1 when the thread count is not positive or any
// executor cannot be created or initialised. Otherwise returns the result of
// Process::initialize.
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
    _train_thread_num = config["train_thread_num"].as<int>();
    // A non-positive count would turn into a huge unsigned size in resize().
    if (_train_thread_num < 1) {
        return -1;
    }
    _threads_executor.resize(_train_thread_num);

    if (config["executor"]) {
        _executor_num = config["executor"].size();
        omp_set_num_threads(_train_thread_num);
        // NOTE(review): `config` is read concurrently from several OpenMP
        // threads below; confirm that yaml-cpp node access is safe for that.
#pragma omp parallel for
        for (int i = 0; i < _train_thread_num; ++i) {
            _threads_executor[i].resize(_executor_num);
            for (int e = 0; e < _executor_num; ++e) {
                auto e_class = config["executor"][e]["class"].as<std::string>();
                auto* e_ptr = CREATE_CLASS(Executor, e_class);
                _threads_executor[i][e].reset(e_ptr);
                // A null executor counts as a failure instead of being dereferenced.
                if (e_ptr == nullptr ||
                    e_ptr->initialize(config["executor"][e], context_ptr) != 0) {
                    // `ret` is shared between threads; serialise the write.
#pragma omp critical
                    {
                        ret = -1;
                    }
                }
            }
        }
    }
    // The original returned 0 unconditionally, hiding every failure above.
    return ret;
}
// Starts saving the model of `table_id` for `epoch_id` and returns a future
// holding the result code (0 = success).
// The save itself is not implemented yet, so the future is always completed
// with 0 before returning. The original left the promise unfulfilled when a
// save was required, and the caller's get() then failed with broken_promise.
std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSaveWay way) {
    std::promise<int> p;
    auto ret = p.get_future();
    if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {
        //TODO
        //context_ptr->pslib_client()->save();
        // Once the save is implemented, this branch must fulfil the promise
        // with the save's result instead of falling through to the 0 below.
    }
    p.set_value(0);
    return ret;
}
// Saves every table in params_table_list for `epoch_id` and waits for all of
// them to finish. Only the master node does any work; other nodes return 0.
// Returns the bitwise OR of the individual save results (0 = all succeeded).
int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
    auto* environment = _context_ptr->environment.get();
    if (!environment->is_master_node()) {
        return 0;
    }

    auto& table_list = _context_ptr->params_table_list;
    // std::vector replaces the original variable-length array of futures,
    // which is not standard C++ and had zero length with an empty table list.
    std::vector<std::future<int>> rets;
    rets.reserve(table_list.size());
    for (auto& table : table_list) {
        rets.push_back(save_model(epoch_id, table.table_id(), way));
    }

    int all_ret = 0;
    for (auto& pending : rets) {
        // get() blocks until the result is available.
        all_ret |= pending.get();
    }
    return all_ret;
}
int
LearnerProcess
::
run
()
{
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
int
epoch_id
=
epoch_accessor
->
current_epoch_id
();
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Resume traine with epoch_id:%d label:%s"
,
epoch_id
,
_context_ptr
->
epoch_accessor
->
text
(
epoch_id
).
c_str
());
//判断是否先dump出base
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier_all
();
while
(
true
)
{
epoch_accessor
->
next_epoch
();
epoch_id
=
epoch_accessor
->
current_epoch_id
();
std
::
string
epoch_log_title
=
paddle
::
string
::
format_string
(
"train epoch_id:%d label:%s"
,
epoch_id
,
epoch_accessor
->
text
(
epoch_id
).
c_str
());
//Step1. 等待样本ready
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Start %s, wait data ready"
,
epoch_log_title
.
c_str
());
while
(
!
epoch_accessor
->
data_ready
(
epoch_id
))
{
sleep
(
30
);
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data not ready, wait 30s"
,
epoch_log_title
.
c_str
());
}
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data is ready, start traning"
,
epoch_log_title
.
c_str
());
environment
->
barrier_all
();
//Step2. 运行训练网络
bool
already_dump_inference_model
=
false
;
for
(
int
i
=
0
;
i
<
_executor_num
;
++
i
)
{
std
::
vector
<
std
::
shared_ptr
<
std
::
thread
>>
train_threads
(
_train_thread_num
);
for
(
int
thread_id
=
0
;
thread_id
<
_train_thread_num
;
++
thread_id
)
{
train_threads
[
i
].
reset
(
new
std
::
thread
([
this
](
int
exe_idx
,
int
thread_idx
)
{
auto
*
executor
=
_threads_executor
[
thread_idx
][
exe_idx
].
get
();
run_executor
(
executor
);
},
i
,
thread_id
));
}
for
(
int
i
=
0
;
i
<
_train_thread_num
;
++
i
)
{
train_threads
[
i
]
->
join
();
}
environment
->
barrier_all
();
if
(
_threads_executor
[
0
][
i
]
->
is_dump_all_model
())
{
already_dump_inference_model
=
true
;
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
environment
->
barrier_all
();
}
//Step3. Dump Model For Delta&&Checkpoint
if
(
!
already_dump_inference_model
)
{
already_dump_inference_model
=
true
;
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpoint
);
environment
->
barrier_all
();
//Step4. Output Monitor && RunStatus
//TODO
}
return
0
;
}
// Runs one training network. Not implemented yet: `executor` is ignored and
// 0 is returned.
int LearnerProcess::run_executor(Executor* executor) {
    // TODO
    return 0;
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
0 → 100644
浏览文件 @
c1c5c20d
/*
*Author: xiexionghang
*Train样本
*/
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// A list of shared executors; LearnerProcess keeps one such list per
// training thread.
using MultiExecutor = std::vector<std::shared_ptr<Executor>>;
// Process that drives training: builds the per-thread executors, then loops
// over epochs, running the executors and saving models.
class LearnerProcess : public Process {
public:
    LearnerProcess() {}
    virtual ~LearnerProcess() {}

    virtual int run();
    virtual int initialize(std::shared_ptr<TrainerContext> context_ptr);

protected:
    // Saves all models synchronously.
    virtual int wait_save_model(int epoch_id, ModelSaveWay way);

    // Saves the specified model asynchronously.
    virtual std::future<int> save_model(int epoch_id, int table_id, ModelSaveWay way);

    // Runs the specified training network.
    virtual int run_executor(Executor* executor);

private:
    int _executor_num = 0;      // number of networks to train
    int _train_thread_num = 1;  // number of parallel training threads
    std::vector<MultiExecutor> _threads_executor;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/process/process.cc
浏览文件 @
c1c5c20d
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
REGISTER_CLASS
(
Process
,
InitEnvProcess
);
REGISTER_CLASS
(
Process
,
LearnerProcess
);
// Default run step of a process: does nothing and returns 0.
int Process::run() {
    return 0;
}
...
...
paddle/fluid/train/custom_trainer/feed/process/process.h
浏览文件 @
c1c5c20d
...
...
@@ -10,8 +10,13 @@ class Process {
public:
Process
()
{}
virtual
~
Process
()
{}
virtual
int
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
=
0
;
virtual
int
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
_context_ptr
=
context_ptr
.
get
();
return
0
;
}
virtual
int
run
();
protected:
TrainerContext
*
_context_ptr
=
NULL
;
};
REGISTER_REGISTERER
(
Process
);
...
...
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
c1c5c20d
...
...
@@ -12,11 +12,31 @@ namespace custom_trainer {
namespace
feed
{
class
Process
;
class
EpochAccessor
;
// Kinds of model dump requested from the epoch accessor and learner process.
enum class ModelSaveWay {
    ModelSaveTrainCheckpoint = 0,
    ModelSaveInferenceDelta  = 1,
    ModelSaveInferenceBase   = 2
};
// Metadata of one parameter table; currently only its id.
class TableMeta {
public:
    TableMeta() {}
    ~TableMeta() {}

    // Id of the table. Declared const so it can be called on const objects.
    int table_id() const {
        return _id;
    }

private:
    // Initialised to 0: the class has no setter, so table_id() previously
    // returned an indeterminate value.
    int _id = 0;
};
// Shared state handed to every process and accessor of the trainer.
class TrainerContext {
public:
    // Parsed trainer configuration.
    YAML::Node trainer_config;
    paddle::platform::CPUPlace cpu_place;

    // Parameter tables whose models are saved by the learner.
    std::vector<TableMeta> params_table_list;

    std::shared_ptr<EpochAccessor> epoch_accessor;
    std::shared_ptr<RuntimeEnvironment> environment;
    std::vector<std::shared_ptr<Process>> process_list;
};
...
...
publish_include.sh
浏览文件 @
c1c5c20d
#!bash
OUTPUT_PATH
=
../../../bc_out/baidu/feed-mlarch/paddle-trainer/output/include/
INCLUDE_DIR
=
paddle/fluid/train/custom_trainer/feed/
SUB_DIR_LIST
=(
common dataset
params_access
or process shuffler
)
SUB_DIR_LIST
=(
common dataset
accessor executor monit
or process shuffler
)
rm
-rf
${
OUTPUT_PATH
}
/
${
INCLUDE_DIR
}
/
*
cp
${
INCLUDE_DIR
}
/
*
.h
${
OUTPUT_PATH
}
/
${
INCLUDE_DIR
}
/
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录