PaddlePaddle / PaddleRec

Commit c1d64d7e
Authored on Aug 09, 2019 by xiexionghang

for dataset pipeline

Parent: 8699c196

Showing 12 changed files with 570 additions and 102 deletions (+570 -102)
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc      +14  -8
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h       +20  -14
paddle/fluid/train/custom_trainer/feed/common/pipeline.h               +131 -0
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc   +75  -15
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h    +29  -24
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h           +12  -2
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc              +66  -0
paddle/fluid/train/custom_trainer/feed/dataset/dataset.h               +44  -0
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc    +130 -13
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h     +44  -21
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc      +3   -3
paddle/fluid/train/custom_trainer/feed/process/learner_process.h       +2   -2
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc

@@ -11,25 +11,31 @@ namespace feed {
 void HourlyEpochAccessor::next_epoch() {
     _current_epoch_id = next_epoch_id(_current_epoch_id);
 }
-std::string HourlyEpochAccessor::text(int epoch_id) {
+std::string HourlyEpochAccessor::text(uint64_t epoch_id) {
     return std::to_string(epoch_id);
 }
-bool HourlyEpochAccessor::data_ready(int epoch_id) {
+bool HourlyEpochAccessor::data_ready(uint64_t epoch_id) {
     return true;
 }
-int HourlyEpochAccessor::next_epoch_id(int epoch_id) {
-    if (epoch_id <= 0) {
+int HourlyEpochAccessor::next_epoch_id(uint64_t epoch_id) {
+    if (epoch_id == 0) {
         struct timeval now;
         gettimeofday(&now, NULL);
         return now.tv_sec / (24 * 3600) * (24 * 3600);
     }
     return epoch_id + 3600;
 }
-bool HourlyEpochAccessor::is_last_epoch(int epoch_id) {
+bool HourlyEpochAccessor::is_last_epoch(uint64_t epoch_id) {
     return ((epoch_id / 3600) % 24) == 23;
 }
-bool HourlyEpochAccessor::need_save_model(int epoch_id, ModelSaveWay save_way) {
-    if (epoch_id <= 0) {
+uint64_t HourlyEpochAccessor::epoch_time_interval() {
+    return 3600;
+}
+uint64_t HourlyEpochAccessor::epoch_timestamp(uint64_t epoch_id) {
+    return epoch_id;
+}
+bool HourlyEpochAccessor::need_save_model(uint64_t epoch_id, ModelSaveWay save_way) {
+    if (epoch_id == 0) {
         return false;
     }
     if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {

@@ -41,7 +47,7 @@ namespace feed {
     }
     return false;
 }
-std::string HourlyEpochAccessor::model_save_path(int epoch_id, ModelSaveWay save_way) {
+std::string HourlyEpochAccessor::model_save_path(uint64_t epoch_id, ModelSaveWay save_way) {
     if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {
         return _model_root_path + "/xbox/delta-" + std::to_string(epoch_id);
     } else if (save_way == ModelSaveWay::ModelSaveInferenceBase) {
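The int-to-uint64_t change above makes an epoch id an absolute UNIX timestamp. A standalone sketch of the hourly epoch arithmetic (not part of the commit; the timestamps are only illustrative):

#include <cassert>
#include <cstdint>

// A fresh run (epoch_id == 0) starts at the current day boundary, every epoch
// advances by one hour, and the 23:00 slot is the last epoch of the day.
uint64_t next_epoch_id(uint64_t epoch_id, uint64_t now_sec) {
    if (epoch_id == 0) {
        return now_sec / (24 * 3600) * (24 * 3600);  // align to the start of the day
    }
    return epoch_id + 3600;                          // one hour per epoch
}

int main() {
    uint64_t now = 1565337600;            // 2019-08-09 08:00:00 UTC
    uint64_t first = next_epoch_id(0, now);
    assert(first == 1565308800);          // 2019-08-09 00:00:00 UTC
    uint64_t last = first + 23 * 3600;
    assert(((last / 3600) % 24) == 23);   // the is_last_epoch() condition
    return 0;
}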
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h

@@ -13,18 +13,22 @@ public:
     virtual int initialize(YAML::Node config,
         std::shared_ptr<TrainerContext> context_ptr) = 0;
-    virtual int current_epoch_id() {
+    virtual uint64_t current_epoch_id() {
         return _current_epoch_id;
     }
     virtual void next_epoch() = 0;
-    virtual std::string text(int epoch_id) = 0;
-    virtual bool data_ready(int epoch_id) = 0;
-    virtual int next_epoch_id(int epoch_id) = 0;
-    virtual bool is_last_epoch(int epoch_id) = 0;
-    virtual bool need_save_model(int epoch_id, ModelSaveWay save_way) = 0;
-    virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way) = 0;
+    virtual std::string text(uint64_t epoch_id) = 0;
+    virtual bool data_ready(uint64_t epoch_id) = 0;
+    virtual int next_epoch_id(uint64_t epoch_id) = 0;
+    virtual bool is_last_epoch(uint64_t epoch_id) = 0;
+    // data time interval between epochs (seconds)
+    virtual uint64_t epoch_time_interval() = 0;
+    // sample data timestamp of an epoch
+    virtual uint64_t epoch_timestamp(uint64_t epoch_id) = 0;
+    virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way) = 0;
+    virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way) = 0;
 protected:
-    int _current_epoch_id;
+    uint64_t _current_epoch_id;
 };
 REGISTER_REGISTERER(EpochAccessor);

@@ -35,12 +39,14 @@ public:
     virtual int initialize(YAML::Node config,
         std::shared_ptr<TrainerContext> context_ptr);
     virtual void next_epoch();
-    virtual std::string text(int epoch_id);
-    virtual bool data_ready(int epoch_id);
-    virtual int next_epoch_id(int epoch_id);
-    virtual bool is_last_epoch(int epoch_id);
-    virtual bool need_save_model(int epoch_id, ModelSaveWay save_way);
-    virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way);
+    virtual std::string text(uint64_t epoch_id);
+    virtual bool data_ready(uint64_t epoch_id);
+    virtual int next_epoch_id(uint64_t epoch_id);
+    virtual bool is_last_epoch(uint64_t epoch_id);
+    virtual uint64_t epoch_time_interval();
+    virtual uint64_t epoch_timestamp(uint64_t epoch_id);
+    virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way);
+    virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way);
 private:
     std::string _model_root_path;
 };
paddle/fluid/train/custom_trainer/feed/common/pipeline.h (new file, 0 → 100644)

#pragma once
#include "paddle/fluid/framework/archive.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class PipelineOptions {
public:
    PipelineOptions() = default;
    uint32_t buffer_data_num = 400;    // number of items buffered; must be larger than batch_size
    uint32_t batch_size = 100;         // batch size when reading from the pipe
    bool need_hold_input_data = false; // keep the input stream data; otherwise release it after consumption
};

/*
 * Data stream pipeline: data flowing in can be converted to another format before flowing out
 *
 * |---------------Pipeline---------------|
 * Channel<IN> -> Converter -> Channel<OUT>
 * Several pipelines can be cascaded via the connect_to method
 *
 * Initialize a pipeline with initialize or connect_to
 */
template <class TypeIn, class TypeOut>
class Pipeline {
public:
    Pipeline() {}
    Pipeline(Pipeline&&) = delete;
    Pipeline(const Pipeline&) = delete;
    typedef std::function<int(const TypeIn*, TypeOut*, size_t num)> PipeDataConverter;

    int initialize(const PipelineOptions& options,
        ::paddle::framework::Channel<TypeIn> input_channel,
        PipeDataConverter data_converter) {
        CHECK(_inited == false);
        CHECK(options.batch_size > 0);
        _inited = true;
        _options = options;
        _is_read_end = false;
        _converter = data_converter;
        _input_channel = input_channel;
        _output_channel = ::paddle::framework::MakeChannel<TypeOut>();
        auto batch_size = options.batch_size;
        auto buffer_data_num = options.buffer_data_num;
        _input_channel->SetBlockSize(batch_size);
        _output_channel->SetBlockSize(batch_size);
        _input_data_buffer.resize(buffer_data_num);
        _output_data_buffer.resize(buffer_data_num);
        if (buffer_data_num / batch_size < 3) {
            buffer_data_num = batch_size * 3;
        }
        buffer_data_num = (buffer_data_num / batch_size) * batch_size;
        _output_channel->SetCapacity(buffer_data_num);
        CHECK(_input_channel != nullptr) << " Input Channel is null";
        _convert_thread = std::make_shared<std::thread>([this](){
            async_convert_data();
        });
        return 0;
    }

    template <class PreTypeIn>
    int connect_to(Pipeline<PreTypeIn, TypeIn>& pre_pipeline,
        PipeDataConverter data_converter) {
        return initialize(pre_pipeline.options(),
            pre_pipeline.output_chnnel(), data_converter);
    }

    virtual ~Pipeline() {
        _is_read_end = true;
        if (_convert_thread != nullptr) {
            _convert_thread->join();
        }
    }

    inline size_t read(std::vector<TypeOut>& p) {
        p.clear();
        size_t num = _output_channel->Read(p);
        return num;
    }
    inline const PipelineOptions& options() {
        return _options;
    }
    inline ::paddle::framework::Channel<TypeOut> output_chnnel() {
        return _output_channel;
    }

private:
    void async_convert_data() {
        size_t convete_batch_size = _input_data_buffer.size() / 4;
        if (convete_batch_size < _options.batch_size * 3) {
            convete_batch_size = 3 * _options.batch_size;
        }
        convete_batch_size = (convete_batch_size / _options.batch_size) * _options.batch_size;
        while (!_is_read_end) {
            while (_output_channel->Size() < _input_data_buffer.size()) {
                size_t read_size = _input_channel->Read(convete_batch_size, &_input_data_buffer[0]);
                if (read_size == 0) {
                    _is_read_end = true;
                    break;
                }
                CHECK(_converter(&_input_data_buffer[0], &_output_data_buffer[0],
                    read_size) == 0) << "Data Converter Do Failed";
                _output_channel->WriteMove(read_size, &_output_data_buffer[0]);
                if (_options.need_hold_input_data) {
                    _input_channel_backup->WriteMove(read_size, &_input_data_buffer[0]);
                }
            }
            sleep(1);
        }
    }

private:
    bool _inited = false;                     // initialization flag
    bool _is_read_end = false;                // set when the input stream has been fully read
    PipelineOptions _options;                 // pipe options
    PipeDataConverter _converter;             // converter
    std::vector<TypeIn> _input_data_buffer;   // input data buffer
    std::vector<TypeOut> _output_data_buffer; // output data buffer
    std::shared_ptr<std::thread> _convert_thread;               // asynchronous convert thread
    ::paddle::framework::Channel<TypeIn> _input_channel;        // input stream
    ::paddle::framework::Channel<TypeIn> _input_channel_backup; // backup of the raw input stream
    ::paddle::framework::Channel<TypeOut> _output_channel;      // output stream
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle
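A usage sketch of the Pipeline template (not part of the commit). It assumes a Paddle build that provides paddle::framework::Channel and MakeChannel, which pipeline.h itself relies on, and an input channel that some producer fills and eventually closes; the int-to-double conversion is only for illustration.

#include <vector>
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"

using paddle::custom_trainer::feed::Pipeline;
using paddle::custom_trainer::feed::PipelineOptions;

void pipeline_demo(::paddle::framework::Channel<int> input) {
    PipelineOptions options;
    options.batch_size = 100;       // size of each batch handed back by read()
    options.buffer_data_num = 400;  // initialize() keeps this at >= 3 * batch_size

    Pipeline<int, double> pipe;
    pipe.initialize(options, input,
        [](const int* in, double* out, size_t num) -> int {
            for (size_t i = 0; i < num; ++i) {
                out[i] = in[i] * 0.5;  // any per-item conversion goes here
            }
            return 0;                  // non-zero trips the CHECK in async_convert_data()
        });

    std::vector<double> batch;
    while (pipe.read(batch) > 0) {
        // consume one converted batch; the loop ends once the producer closes `input`
    }
}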
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc

@@ -4,29 +4,89 @@ namespace paddle {
 namespace custom_trainer {
 namespace feed {
-// initialize from config
-int MPIRuntimeEnvironment::initialize(YAML::Node config) {
-    return 0;
-}
-// environment wire-up, called after all dependent modules have been initialized
-int MPIRuntimeEnvironment::wireup() {
-    return 0;
-}
-// rank index in the current environment
-uint32_t MPIRuntimeEnvironment::rank_idx() {
-    return 0;
-}
-void MPIRuntimeEnvironment::barrier_all() {
-    return;
-}
-void MPIRuntimeEnvironment::print_log(EnvironmentLogType type, EnvironmentLogLevel level,
-    const std::string& log_str) {
-    if (type == EnvironmentLogType::MASTER_LOG && !is_master_node()) {
-        return;
-    }
-    VLOG(2) << log_str;
-    return;
-}
-REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
+RuntimeEnvironment::RuntimeEnvironment() {}
+RuntimeEnvironment::~RuntimeEnvironment() {}
+bool RuntimeEnvironment::is_master_node(EnvironmentRole role) {
+    return rank_id(role) == 0;
+}
+std::string format_timestamp(time_t time, const char* format) {
+    std::string result;
+    struct tm p = *localtime(&time);
+    char time_str_buffer[64];
+    int size = strftime(time_str_buffer, 64, format, &p);
+    if (size > 0) {
+        result.assign(time_str_buffer, size);
+    }
+    return result;
+}
+struct MpiNodeInfo {
+    int rank_id = -1;
+    int node_num = 0;
+    MPI_Comm mpi_comm;
+};
+
+class MPIRuntimeEnvironment : public RuntimeEnvironment {
+public:
+    MPIRuntimeEnvironment() {}
+    virtual ~MPIRuntimeEnvironment() {}
+    virtual int initialize(YAML::Node config) {
+        return 0;
+    }
+    virtual int wireup() {
+        int hr = MPI_Init(NULL, NULL);
+        if (MPI_SUCCESS != hr) {
+            LOG(FATAL) << "MPI_init failed with error code" << hr;
+            return -1;
+        }
+        _roles_node_info.resize(static_cast<int>(EnvironmentRole::ALL) + 1);
+        set_role(EnvironmentRole::ALL);
+        return 0;
+    }
+    virtual uint32_t rank_id(EnvironmentRole role) {
+        return mpi_node_info(role).rank_id;
+    }
+    virtual uint32_t node_num(EnvironmentRole role) {
+        return mpi_node_info(role).node_num;
+    }
+    virtual int set_role(EnvironmentRole role) {
+        auto& node_info = mpi_node_info(role);
+        if (node_info.rank_id < 0) {
+            if (role == EnvironmentRole::ALL) {
+                node_info.mpi_comm = MPI_COMM_WORLD;
+            } else {
+                MPI_Comm_split(MPI_COMM_WORLD, static_cast<int>(role),
+                    mpi_node_info(EnvironmentRole::ALL).rank_id, &(node_info.mpi_comm));
+            }
+            MPI_Comm_rank(node_info.mpi_comm, &(node_info.rank_id));
+            MPI_Comm_size(node_info.mpi_comm, &(node_info.node_num));
+        }
+        return 0;
+    }
+    virtual void barrier(EnvironmentRole role) {
+        MPI_Barrier(mpi_node_info(role).mpi_comm);
+    }
+    virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
+        auto& node_info = mpi_node_info(role);
+        int len = (int)ar.length();
+        MPI_Bcast(&len, 1, MPI_INT, root_id, node_info.mpi_comm);
+        ar.resize(len);
+        ar.set_cursor(ar.buffer());
+        MPI_Bcast(ar.buffer(), len, MPI_BYTE, root, node_info.mpi_comm);
+    }
+protected:
+    virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level,
+        const std::string& log_str);
+    inline MpiNodeInfo& mpi_node_info(EnvironmentRole role) {
+        return _roles_node_info[static_cast<int>(role)];
+    }
+private:
+    std::vector<MpiNodeInfo> _roles_node_info;
+};
+REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
 }  // namespace feed
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h

@@ -6,6 +6,7 @@
  */
 #pragma once
 #include <yaml-cpp/yaml.h>
+#include "paddle/fluid/framework/archive.h"
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"

@@ -21,26 +22,37 @@ enum class EnvironmentLogLevel {
 };
 enum class EnvironmentLogType {
     MASTER_LOG = 0,  // only the master node writes output
     ALL_LOG = 1      // every node writes output
 };
+// keep the enum values contiguous and increasing, with ALL at the end
+enum class EnvironmentRole {
+    WORKER = 0,      // training worker
+    PSERVER = 1,     // parameter server
+    ALL = 2          // all roles; keep this at the end of the enum
+};
 class RuntimeEnvironment {
 public:
-    RuntimeEnvironment() {}
-    virtual ~RuntimeEnvironment() {}
+    RuntimeEnvironment();
+    virtual ~RuntimeEnvironment();
     // initialize from config
     virtual int initialize(YAML::Node config) = 0;
+    // set the role
+    virtual int set_role(EnvironmentRole role) = 0;
     // environment wire-up, called after all dependent modules have been initialized
     virtual int wireup() = 0;
     // interfaces callable from multiple threads -- Start
     // rank index in the current environment
-    virtual uint32_t rank_idx() = 0;
+    virtual uint32_t rank_id(EnvironmentRole role) = 0;
+    // number of nodes in the runtime environment
+    virtual uint32_t node_num(EnvironmentRole role) = 0;
     // master node of the environment
-    virtual bool is_master_node() {
-        return rank_idx() == 0;
-    }
+    virtual bool is_master_node(EnvironmentRole role);
     // environment-specific logging
     template<class... ARGS>
     void log(EnvironmentLogType type, EnvironmentLogLevel level,

@@ -51,29 +63,22 @@ public:
     // interfaces that may only be called from the main thread -- Start
-    // barrier
-    virtual void barrier_all() = 0;
+    // barrier across the nodes of the given role
+    virtual void barrier(EnvironmentRole role) = 0;
+    // bcast broadcast
+    virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0;
     // interfaces that may only be called from the main thread -- End
 protected:
     virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level,
         const std::string& log_str) = 0;
 };
 REGISTER_REGISTERER(RuntimeEnvironment);
-class MPIRuntimeEnvironment : public RuntimeEnvironment {
-public:
-    MPIRuntimeEnvironment() {}
-    virtual ~MPIRuntimeEnvironment() {}
-    // initialize from config
-    virtual int initialize(YAML::Node config);
-    // environment wire-up, called after all dependent modules have been initialized
-    virtual int wireup();
-    // rank index in the current environment
-    virtual uint32_t rank_idx();
-    virtual void barrier_all();
-protected:
-    virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level,
-        const std::string& log_str);
-};
+std::string format_timestamp(time_t time, const char* format);
+std::string format_timestamp(time_t time, const std::string& format) {
+    return format_timestamp(time, format.c_str());
+}
 }  // namespace feed
 }  // namespace custom_trainer
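A usage sketch of the role-aware environment API above (not part of the commit). The factory call mirrors the CREATE_CLASS(DataReader, ...) call used in dataset_container.cc, and "MPIRuntimeEnvironment" is the class registered in runtime_environment.cc, but treat the wiring itself as an assumption rather than the trainer's actual start-up code.

#include <memory>
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"

using namespace paddle::custom_trainer::feed;

int sync_flag_across_workers(YAML::Node config) {
    std::shared_ptr<RuntimeEnvironment> env(
        CREATE_CLASS(RuntimeEnvironment, "MPIRuntimeEnvironment"));
    if (env->initialize(config) != 0 || env->wireup() != 0) {
        return -1;                              // wireup initializes MPI and the ALL role
    }
    env->set_role(EnvironmentRole::WORKER);     // split a communicator for the WORKER role

    int flag = env->is_master_node(EnvironmentRole::WORKER) ? 1 : -1;
    paddle::framework::BinaryArchive ar;
    ar << flag;
    env->bcast(ar, 0, EnvironmentRole::WORKER); // rank 0 of the WORKER group is the root
    ar >> flag;                                 // every worker now sees the master's value

    env->barrier(EnvironmentRole::WORKER);
    return flag;
}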
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h

@@ -8,6 +8,7 @@
 #include <memory>
 #include <yaml-cpp/yaml.h>
 #include "paddle/fluid/framework/channel.h"
+#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
 #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"

 namespace paddle {

@@ -36,6 +37,11 @@ public:
     std::string data;  // sample data, possibly compressed
 };
+typedef std::shared_ptr<Pipeline<DataItem, SampleInstance>> SampleInstancePipe;
+inline SampleInstancePipe make_sample_instance_channel() {
+    return std::make_shared<Pipeline<DataItem, SampleInstance>>();
+}
 class DataParser {
 public:
     DataParser() {}

@@ -56,8 +62,12 @@ public:
     virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
     // whether the sample data is ready; ready means the download can start
     virtual bool is_data_ready(const std::string& data_dir) = 0;
-    // read data into the sample stream
-    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
+    // list the files under the directory
+    virtual std::vector<std::string> data_file_list(const std::string& data_dir);
+    // read the data under a directory into the sample stream
+    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
+    // read the data of the given file list into the sample stream
+    virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
     virtual const DataParser* get_parser() {
         return _parser.get();
     }
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc (new file, 0 → 100644)

#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

int Dataset::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    if (!config["data_list"]) {
        VLOG(0) << "miss data_list config in dataset, please check";
        return -1;
    }
    int data_num = config["data_list"].size();
    for (int i = 0; i < data_num; ++i) {
        std::string name = config["data_list"][i]["name"].as<std::string>();
        auto data_ptr = std::make_shared<DatasetContainer>();
        if (data_ptr->initialize(config["data_list"][i], context) != 0) {
            VLOG(0) << "dataset initialize failed, name:" << name;
            return -1;
        }
        _data_containers[name] = data_ptr;
    }
    return 0;
}

inline void Dataset::pre_detect_data(const std::string& data_name, uint64_t epoch_id) {
    _data_containers[data_name]->pre_detect_data(epoch_id);
    return;
}

inline DatasetStatus Dataset::epoch_data_status(const std::string& data_name, uint64_t epoch_id) {
    return _data_containers[data_name]->epoch_data_status(epoch_id);
}

inline ::paddle::framework::Channel<DataItem> Dataset::fetch_data(const std::string& data_name, uint64_t epoch_id) {
    return _data_containers[data_name]->fetch(epoch_id);
}

SampleInstancePipe Dataset::fetch_sample(const std::string& data_name, uint32_t batch_size, uint64_t epoch_id) {
    auto* data_container = _data_containers[data_name].get();
    auto data_channel = data_container->fetch(epoch_id);
    const auto* data_parser = data_container->data_parser();
    PipelineOptions options;
    options.batch_size = batch_size;
    options.need_hold_input_data = true;
    options.buffer_data_num = batch_size * 10;
    SampleInstancePipe pipe = make_sample_instance_channel();
    pipe->initialize(options, data_channel,
        [data_parser] (const DataItem* data, SampleInstance* sample, size_t num) -> int {
            int ret = 0;
            for (int i = 0; i < num; ++i, ++data, ++sample) {
                ret |= data_parser->parse_to_sample(*data, *sample);
            }
            return ret;
        });
    return pipe;
}

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/dataset.h (new file, 0 → 100644)

#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class Dataset {
public:
    Dataset() {}
    virtual ~Dataset() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);

    // trigger detection of prefetchable data
    virtual void pre_detect_data(const std::string& data_name, uint64_t epoch_id);

    // query the data status
    virtual DatasetStatus epoch_data_status(const std::string& data_name, uint64_t epoch_id);

    // return the raw data held by each DataContainer (possibly compressed)
    virtual ::paddle::framework::Channel<DataItem> fetch_data(const std::string& data_name, uint64_t epoch_id);

    // return the standard sample stream as a pipeline; the pipeline converts the data asynchronously
    virtual SampleInstancePipe fetch_sample(const std::string& data_name, uint32_t batch_size, uint64_t epoch_id);

private:
    std::unordered_map<std::string, std::shared_ptr<DatasetContainer>> _data_containers;
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle
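A usage sketch of the new Dataset API (not part of the commit). It assumes a fully initialized TrainerContext behind the Dataset; the data name "join" and the batch size of 32 are made up for illustration.

#include <vector>
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"

using namespace paddle::custom_trainer::feed;

void train_one_epoch(Dataset& dataset, uint64_t epoch_id) {
    dataset.pre_detect_data("join", epoch_id);   // let the container detect the epoch's file list
    if (dataset.epoch_data_status("join", epoch_id) != DatasetStatus::Ready) {
        return;                                  // the downloader has not filled the channel yet
    }
    SampleInstancePipe pipe = dataset.fetch_sample("join", 32, epoch_id);
    std::vector<SampleInstance> batch;
    while (pipe->read(batch) > 0) {
        // hand `batch` to the executor
    }
}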
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc

@@ -8,31 +8,148 @@
 #include <yaml-cpp/yaml.h>
 #include "paddle/fluid/framework/io/shell.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
+#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
 #include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"

 namespace paddle {
 namespace custom_trainer {
 namespace feed {
+int DatasetContainer::initialize(const YAML::Node& config,
+    std::shared_ptr<TrainerContext> context) {
+    _dataset_config = config;
+    _trainer_context = context.get();
+    // prefetch n rounds of sample data
+    _prefetch_num = config["prefetch_num"].as<int>();
+    _dataset_list.resize(_prefetch_num);
+    _data_root_paths = paddle::string::split_string(config["root_path"].as<std::string>(), " ");
+    _data_split_interval = config["data_spit_interval"].as<int>();
+    _data_path_formater = config["data_path_formater"].as<std::string>();
+    std::string data_reader_class = config["data_reader"].as<std::string>();
+    DataReader* data_reader = CREATE_CLASS(DataReader, data_reader_class);
+    _data_reader.reset(data_reader);
+    return _data_reader->initialize(config, context);
+}
+std::shared_ptr<DatasetInfo> DatasetContainer::dataset(uint64_t timestamp) {
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    auto data_idx = timestamp / epoch_accessor->epoch_time_interval();
+    return _dataset_list[data_idx % _prefetch_num];
+}
+void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
+    int status = 0;
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
+    if (timestamp % epoch_accessor->epoch_time_interval() != 0) {
+        LOG(FATAL) << "timestamp:" << timestamp << " don't match interval:" << epoch_accessor->epoch_time_interval();
+        return;
+    }
+    size_t data_num = data_num_for_train(timestamp, epoch_accessor->epoch_time_interval(), _data_split_interval);
+    uint64_t data_timestamp = timestamp % _data_split_interval == 0 ?
+        timestamp : (timestamp / _data_split_interval + 1) * _data_split_interval;
+    std::vector<std::string> data_path_list;
+    for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) {
+        for (int j = 0; j < data_num && status == 0; ++j) {
+            std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
+            std::string data_dir = _data_root_paths[i] + "/" + path_suffix;
+            status = read_data_list(data_dir, data_path_list);
+        }
+    }
+    if (status == 0) {
+        auto dataset_info = dataset(timestamp);
+        dataset_info->timestamp = timestamp;
+        dataset_info->file_path_list = std::move(data_path_list);
+        dataset_info->status = DatasetStatus::Detected;
+    }
+    return;
+}
+int DatasetContainer::read_data_list(const std::string& data_dir, std::vector<std::string>& data_list) {
+    auto* environment = _trainer_context->environment.get();
+    // check whether the data is ready
+    int data_status = -1;
+    if (environment->is_master_node(EnvironmentRole::WORKER)) {
+        if (_data_reader->is_data_ready(data_dir)) {
+            data_status = 0;
+        }
+    }
+    paddle::framework::BinaryArchive ar;
+    ar << data_status;
+    environment->bcast(ar, 0, EnvironmentRole::WORKER);
+    ar >> data_status;
+    if (data_status != 0) {
+        return -1;
+    }
+    // read the file list
+    ar.Clear();
+    std::vector<std::string> data_path_list;
+    if (environment->is_master_node(EnvironmentRole::WORKER)) {
+        data_path_list = _data_reader->data_file_list(data_dir);
+        ar << data_path_list;
+    }
+    environment->bcast(ar, 0, EnvironmentRole::WORKER);
+    ar >> data_path_list;
+    auto worker_id = environment->rank_id(EnvironmentRole::WORKER);
+    auto worker_num = environment->node_num(EnvironmentRole::WORKER);
+    for (int i = worker_id; i < data_path_list.size(); i += worker_num) {
+        data_list.push_back(data_path_list[i]);
+    }
+    environment->barrier(EnvironmentRole::WORKER);
+    return 0;
+}
+DatasetStatus DatasetContainer::epoch_data_status(uint64_t epoch_id) {
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
+    return data_status(timestamp);
+}
+DatasetStatus DatasetContainer::data_status(uint64_t timestamp) {
+    auto dataset_info = dataset(timestamp);
+    if (dataset_info->timestamp != timestamp) {
+        return DatasetStatus::Empty;
+    }
+    return dataset_info->status;
+}
-paddle::framework::Channel<DataItem> DatasetContainer::fetch(int epoch_id) {
+paddle::framework::Channel<DataItem> DatasetContainer::fetch(uint64_t epoch_id) {
     paddle::framework::Channel<DataItem> result;
-    if (_ready_epoch_id < epoch_id) {
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
+    if (data_status(timestamp) != DatasetStatus::Ready) {
         return result;
     }
-    _current_epoch_id = epoch_id;
-    _current_dataset_idx = epoch_id % _prefetch_num;
-    //result = _dataset_list[_current_dataset_idx].fetch();
-    //_dataset_list[_current_dataset_idx].reset((decltype(result.get())*)NULL);
-    return result;
+    auto dataset_info = dataset(timestamp);
+    return dataset_info->data_channel;
 }
-void DatasetContainer::async_download_data() {
+void DatasetContainer::async_download_data(uint64_t start_timestamp) {
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    if (start_timestamp % epoch_accessor->epoch_time_interval() != 0) {
+        LOG(FATAL) << "timestamp:" << start_timestamp << " don't match interval:" << epoch_accessor->epoch_time_interval();
+        return;
+    }
     while (true) {
-        //do download
-        sleep(30);
+        auto dataset_info = dataset(start_timestamp);
+        while (data_status(start_timestamp) != DatasetStatus::Detected) {
+            sleep(30);
+        }
+        const auto& file_list = dataset_info->file_path_list;
+        dataset_info->data_channel->Clear();
+        while (_data_reader->read_all(file_list, dataset_info->data_channel) != 0) {
+            dataset_info->data_channel->Clear();
+            VLOG(0) << "timestamp:" << start_timestamp << " data read failed, retry";
+            sleep(30);
+        }
+        start_timestamp += epoch_accessor->epoch_time_interval();
     }
 }
 }  // namespace feed
 }  // namespace custom_trainer
 }  // namespace paddle
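The file-list distribution in read_data_list above is a plain round-robin shard by worker rank. A standalone illustration (the part-* file names are made up):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> data_path_list = {"part-0", "part-1", "part-2", "part-3", "part-4"};
    uint32_t worker_num = 2;   // environment->node_num(EnvironmentRole::WORKER)
    for (uint32_t worker_id = 0; worker_id < worker_num; ++worker_id) {
        std::vector<std::string> data_list;
        // each worker keeps the entries whose index is congruent to its rank modulo worker_num
        for (size_t i = worker_id; i < data_path_list.size(); i += worker_num) {
            data_list.push_back(data_path_list[i]);
        }
        std::cout << "worker " << worker_id << " reads " << data_list.size() << " files\n";
    }
    return 0;
}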
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h

@@ -16,38 +16,61 @@ namespace paddle {
 namespace custom_trainer {
 namespace feed {
+inline int data_num_for_train(uint64_t train_begin_timestamp,
+    uint32_t train_time_interval, uint32_t data_time_interval) {
+    uint64_t data_begin_time = train_begin_timestamp;
+    uint64_t data_end_time = data_begin_time + train_time_interval;
+    uint64_t end_idx = (data_end_time + data_time_interval - 1) / data_time_interval;
+    uint64_t begin_idx = (data_begin_time + data_time_interval - 1) / data_time_interval;
+    return end_idx - begin_idx;
+}
+enum class DatasetStatus {
+    Empty = 0,
+    Detected = 1,
+    Downloding = 2,
+    Ready = 3
+};
+struct DatasetInfo {
+    uint64_t timestamp = 0;
+    std::vector<std::string> file_path_list;
+    DatasetStatus status = DatasetStatus::Empty;
+    ::paddle::framework::Channel<DataItem> data_channel = ::paddle::framework::MakeChannel<DataItem>();
+};
 class DatasetContainer {
 public:
     DatasetContainer() {}
     virtual ~DatasetContainer() {}
-    virtual int initialize(const YAML::Node& config) {
-        _dataset_config = config;
-        // prefetch n rounds of sample data
-        _prefetch_num = config["prefetch_num"].as<int>();
-        _data_root_path = config["root_path"].as<std::string>();
-        _data_path_generater = config["_data_path_generater"].as<std::string>();
-        return 0;
-    }
+    virtual int initialize(
+        const YAML::Node& config, std::shared_ptr<TrainerContext> context);
     virtual void run();
-    // fetch the samples of a given epoch; if the data is not ready the Channel holds a null pointer
-    virtual ::paddle::framework::Channel<DataItem> fetch(int epoch_id);
     // trigger detection of prefetchable data
-    virtual void pre_detect_data(RuntimeEnvironment* env);
+    virtual void pre_detect_data(uint64_t epoch_id);
+    // query the data status
+    virtual DatasetStatus epoch_data_status(uint64_t epoch_id);
+    // fetch the samples of a given epoch; if the data is not ready the Channel holds a null pointer
+    virtual ::paddle::framework::Channel<DataItem> fetch(uint64_t epoch_id);
+    // get the DataItem parser
+    virtual const DataParser* data_parser() {
+        return _data_reader->get_parser();
+    }
 protected:
+    virtual DatasetStatus data_status(uint64_t timestamp);
+    virtual int read_data_list(const std::string& data_dir, std::vector<std::string>& data_list);
     // asynchronous sample download
-    virtual void async_download_data();
-    virtual void download(int epoch_id, const std::vector<std::string>& paths);
+    virtual void async_download_data(uint64_t start_timestamp);
+    virtual std::shared_ptr<DatasetInfo> dataset(uint64_t timestamp);

     int _prefetch_num = 0;
+    int _data_split_interval = 60;  // sample split interval (seconds)
     YAML::Node _dataset_config;
-    std::string _data_root_path;
-    std::string _data_path_generater;
-    uint32_t _current_dataset_idx;   // index of the current sample data
-    int _current_epoch_id = -1;
-    int _ready_epoch_id = -1;        // epoch_id whose download has completed
-    std::vector<std::shared_ptr<::paddle::framework::Dataset>> _dataset_list;  // list of prefetched data
+    std::string _data_path_formater;
+    std::vector<std::string> _data_root_paths;  // multiple root directories can be read at the same time
+    TrainerContext* _trainer_context;
+    std::shared_ptr<DataReader> _data_reader;
+    std::shared_ptr<std::thread> _downloader_thread;
+    std::vector<std::shared_ptr<DatasetInfo>> _dataset_list;  // list of prefetched data
 };
 }  // namespace feed
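A standalone check of the data_num_for_train helper above: with an hourly epoch (3600 s) and the default 60 s data split interval, one epoch spans 60 data pieces.

#include <cassert>
#include <cstdint>

inline int data_num_for_train(uint64_t train_begin_timestamp,
    uint32_t train_time_interval, uint32_t data_time_interval) {
    uint64_t data_begin_time = train_begin_timestamp;
    uint64_t data_end_time = data_begin_time + train_time_interval;
    uint64_t end_idx = (data_end_time + data_time_interval - 1) / data_time_interval;
    uint64_t begin_idx = (data_begin_time + data_time_interval - 1) / data_time_interval;
    return end_idx - begin_idx;
}

int main() {
    assert(data_num_for_train(1565308800, 3600, 60) == 60);  // epoch start aligned to the minute
    assert(data_num_for_train(1565308830, 3600, 60) == 60);  // an unaligned start still spans 60 pieces
    return 0;
}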
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc

@@ -35,7 +35,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
     return 0;
 }
-std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSaveWay way) {
+std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, ModelSaveWay way) {
     std::promise<int> p;
     auto ret = p.get_future();
     if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {

@@ -47,7 +47,7 @@ std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSav
     return ret;
 }
-int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
+int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
     auto* environment = _context_ptr->environment.get();
     if (!environment->is_master_node()) {
         return 0;

@@ -71,7 +71,7 @@ int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
 int LearnerProcess::run() {
     auto* environment = _context_ptr->environment.get();
     auto* epoch_accessor = _context_ptr->epoch_accessor.get();
-    int epoch_id = epoch_accessor->current_epoch_id();
+    uint64_t epoch_id = epoch_accessor->current_epoch_id();
     environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
         "Resume traine with epoch_id:%d label:%s", epoch_id,
         _context_ptr->epoch_accessor->text(epoch_id).c_str());
paddle/fluid/train/custom_trainer/feed/process/learner_process.h

@@ -21,9 +21,9 @@ public:
 protected:
     // save all models synchronously
-    virtual int wait_save_model(int epoch_id, ModelSaveWay way);
+    virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way);
     // save the given model asynchronously
-    virtual std::future<int> save_model(int epoch_id, int table_id, ModelSaveWay way);
+    virtual std::future<int> save_model(uint64_t epoch_id, int table_id, ModelSaveWay way);
     // run the given training network
     virtual int run_executor(Executor* executor);
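save_model above hands its caller a std::future<int> backed by a std::promise<int>, exactly as the learner_process.cc hunk shows. A standalone illustration of that pattern (not specific to LearnerProcess; the rest of save_model's body is elided in this commit's hunk, so when the promise is fulfilled is an assumption):

#include <future>
#include <iostream>

int main() {
    std::promise<int> p;
    std::future<int> ret = p.get_future();
    p.set_value(0);   // presumably set once the asynchronous save request completes
    std::cout << "save status: " << ret.get() << "\n";
    return 0;
}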