Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
c1d64d7e
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c1d64d7e
编写于
8月 09, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
for dataset pipeline
上级
8699c196
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
570 addition
and
102 deletion
+570
-102
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+14
-8
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
...fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+20
-14
paddle/fluid/train/custom_trainer/feed/common/pipeline.h
paddle/fluid/train/custom_trainer/feed/common/pipeline.h
+131
-0
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
...d/train/custom_trainer/feed/common/runtime_environment.cc
+75
-15
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
...id/train/custom_trainer/feed/common/runtime_environment.h
+29
-24
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+12
-2
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
+66
-0
paddle/fluid/train/custom_trainer/feed/dataset/dataset.h
paddle/fluid/train/custom_trainer/feed/dataset/dataset.h
+44
-0
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
...id/train/custom_trainer/feed/dataset/dataset_container.cc
+130
-13
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
...uid/train/custom_trainer/feed/dataset/dataset_container.h
+44
-21
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+3
-3
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
...fluid/train/custom_trainer/feed/process/learner_process.h
+2
-2
未找到文件。
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
浏览文件 @
c1d64d7e
...
...
@@ -11,25 +11,31 @@ namespace feed {
// Advances the accessor to the epoch that follows the current one.
void HourlyEpochAccessor::next_epoch() {
    const auto upcoming_epoch_id = next_epoch_id(_current_epoch_id);
    _current_epoch_id = upcoming_epoch_id;
}
std
::
string
HourlyEpochAccessor
::
text
(
in
t
epoch_id
)
{
std
::
string
HourlyEpochAccessor
::
text
(
uint64_
t
epoch_id
)
{
return
std
::
to_string
(
epoch_id
);
}
bool
HourlyEpochAccessor
::
data_ready
(
in
t
epoch_id
)
{
bool
HourlyEpochAccessor
::
data_ready
(
uint64_
t
epoch_id
)
{
return
true
;
}
int
HourlyEpochAccessor
::
next_epoch_id
(
in
t
epoch_id
)
{
if
(
epoch_id
<
=
0
)
{
int
HourlyEpochAccessor
::
next_epoch_id
(
uint64_
t
epoch_id
)
{
if
(
epoch_id
=
=
0
)
{
struct
timeval
now
;
gettimeofday
(
&
now
,
NULL
);
return
now
.
tv_sec
/
(
24
*
3600
)
*
(
24
*
3600
);
}
return
epoch_id
+
3600
;
}
bool
HourlyEpochAccessor
::
is_last_epoch
(
in
t
epoch_id
)
{
bool
HourlyEpochAccessor
::
is_last_epoch
(
uint64_
t
epoch_id
)
{
return
((
epoch_id
/
3600
)
%
24
)
==
23
;
}
bool
HourlyEpochAccessor
::
need_save_model
(
int
epoch_id
,
ModelSaveWay
save_way
)
{
if
(
epoch_id
<=
0
)
{
// Time span between two consecutive epochs; this accessor always reports 3600.
uint64_t HourlyEpochAccessor::epoch_time_interval() {
    const uint64_t interval = 3600;
    return interval;
}
// Maps an epoch id to the timestamp of its sample data.
// For this accessor the epoch id is returned unchanged.
uint64_t HourlyEpochAccessor::epoch_timestamp(uint64_t epoch_id) {
    const uint64_t timestamp = epoch_id;
    return timestamp;
}
bool
HourlyEpochAccessor
::
need_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
{
if
(
epoch_id
==
0
)
{
return
false
;
}
if
(
save_way
==
ModelSaveWay
::
ModelSaveInferenceDelta
)
{
...
...
@@ -41,7 +47,7 @@ namespace feed {
}
return
false
;
}
std
::
string
HourlyEpochAccessor
::
model_save_path
(
in
t
epoch_id
,
ModelSaveWay
save_way
)
{
std
::
string
HourlyEpochAccessor
::
model_save_path
(
uint64_
t
epoch_id
,
ModelSaveWay
save_way
)
{
if
(
save_way
==
ModelSaveWay
::
ModelSaveInferenceDelta
)
{
return
_model_root_path
+
"/xbox/delta-"
+
std
::
to_string
(
epoch_id
);
}
else
if
(
save_way
==
ModelSaveWay
::
ModelSaveInferenceBase
)
{
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
浏览文件 @
c1d64d7e
...
...
@@ -13,18 +13,22 @@ public:
virtual
int
initialize
(
YAML
::
Node
config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
=
0
;
virtual
in
t
current_epoch_id
()
{
virtual
uint64_
t
current_epoch_id
()
{
return
_current_epoch_id
;
}
virtual
void
next_epoch
()
=
0
;
virtual
std
::
string
text
(
int
epoch_id
)
=
0
;
virtual
bool
data_ready
(
int
epoch_id
)
=
0
;
virtual
int
next_epoch_id
(
int
epoch_id
)
=
0
;
virtual
bool
is_last_epoch
(
int
epoch_id
)
=
0
;
virtual
bool
need_save_model
(
int
epoch_id
,
ModelSaveWay
save_way
)
=
0
;
virtual
std
::
string
model_save_path
(
int
epoch_id
,
ModelSaveWay
save_way
)
=
0
;
virtual
std
::
string
text
(
uint64_t
epoch_id
)
=
0
;
virtual
bool
data_ready
(
uint64_t
epoch_id
)
=
0
;
virtual
int
next_epoch_id
(
uint64_t
epoch_id
)
=
0
;
virtual
bool
is_last_epoch
(
uint64_t
epoch_id
)
=
0
;
//epoch间的数据时间间隔(秒)
virtual
uint64_t
epoch_time_interval
()
=
0
;
//获取epoch的样本数据时间
virtual
uint64_t
epoch_timestamp
(
uint64_t
epoch_id
)
=
0
;
virtual
bool
need_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
=
0
;
virtual
std
::
string
model_save_path
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
=
0
;
protected:
in
t
_current_epoch_id
;
uint64_
t
_current_epoch_id
;
};
REGISTER_REGISTERER
(
EpochAccessor
);
...
...
@@ -35,12 +39,14 @@ public:
virtual
int
initialize
(
YAML
::
Node
config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
void
next_epoch
();
virtual
std
::
string
text
(
int
epoch_id
);
virtual
bool
data_ready
(
int
epoch_id
);
virtual
int
next_epoch_id
(
int
epoch_id
);
virtual
bool
is_last_epoch
(
int
epoch_id
);
virtual
bool
need_save_model
(
int
epoch_id
,
ModelSaveWay
save_way
);
virtual
std
::
string
model_save_path
(
int
epoch_id
,
ModelSaveWay
save_way
);
virtual
std
::
string
text
(
uint64_t
epoch_id
);
virtual
bool
data_ready
(
uint64_t
epoch_id
);
virtual
int
next_epoch_id
(
uint64_t
epoch_id
);
virtual
bool
is_last_epoch
(
uint64_t
epoch_id
);
virtual
uint64_t
epoch_time_interval
();
virtual
uint64_t
epoch_timestamp
(
uint64_t
epoch_id
);
virtual
bool
need_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
);
virtual
std
::
string
model_save_path
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
);
private:
std
::
string
_model_root_path
;
};
...
...
paddle/fluid/train/custom_trainer/feed/common/pipeline.h
0 → 100644
浏览文件 @
c1d64d7e
#pragma once
#include <unistd.h>

#include <atomic>
#include <functional>
#include <memory>
#include <thread>
#include <vector>

#include "paddle/fluid/framework/archive.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Configuration for a Pipeline.
class PipelineOptions {
public:
    PipelineOptions() = default;

    // Number of items held in the pipeline buffers; must be larger than batch_size.
    uint32_t buffer_data_num{400};
    // Batch size used when reading data from the pipe.
    uint32_t batch_size{100};
    // Whether to keep the input-stream data; otherwise it is released once consumed.
    bool need_hold_input_data{false};
};
/*
 * Data-flow pipeline. Data entering the pipeline can be converted to another
 * format before it flows out:
 *
 *   |---------------Pipeline---------------|
 *   Channel<IN> -> Converter -> Channel<OUT>
 *
 * Several pipelines can be cascaded with connect_to().
 * Set a pipeline up with either initialize() or connect_to().
 */
template <class TypeIn, class TypeOut>
class Pipeline {
public:
    Pipeline() {}
    Pipeline(Pipeline&&) = delete;
    Pipeline(const Pipeline&) = delete;

    // Converts `num` items from the input array into the output array.
    // Must return 0 on success; any other value aborts via CHECK.
    typedef std::function<int(const TypeIn*, TypeOut*, size_t num)> PipeDataConverter;

    // Binds the pipeline to `input_channel`, creates the output channel and
    // starts the background conversion thread. May only be called once.
    // Always returns 0; invalid arguments abort via CHECK.
    int initialize(const PipelineOptions& options,
        ::paddle::framework::Channel<TypeIn> input_channel,
        PipeDataConverter data_converter) {
        CHECK(_inited == false);
        CHECK(options.batch_size > 0);
        // Validate the channel before it is dereferenced below.
        CHECK(input_channel != nullptr) << " Input Channel is null";
        _inited = true;
        _options = options;
        _is_read_end = false;
        _converter = data_converter;
        _input_channel = input_channel;
        _output_channel = ::paddle::framework::MakeChannel<TypeOut>();
        if (_options.need_hold_input_data) {
            // The convert thread writes consumed input here, so it must exist.
            _input_channel_backup = ::paddle::framework::MakeChannel<TypeIn>();
        }

        auto batch_size = options.batch_size;
        auto buffer_data_num = options.buffer_data_num;
        // Keep at least three batches of buffer, rounded down to whole batches.
        if (buffer_data_num / batch_size < 3) {
            buffer_data_num = batch_size * 3;
        }
        buffer_data_num = (buffer_data_num / batch_size) * batch_size;

        _input_channel->SetBlockSize(batch_size);
        _output_channel->SetBlockSize(batch_size);
        // Size the buffers with the adjusted count: async_convert_data() reads
        // at least 3 * batch_size items per call into _input_data_buffer.
        _input_data_buffer.resize(buffer_data_num);
        _output_data_buffer.resize(buffer_data_num);
        _output_channel->SetCapacity(buffer_data_num);

        _convert_thread = std::make_shared<std::thread>([this]() {
            async_convert_data();
        });
        return 0;
    }

    // Cascades this pipeline behind `pre_pipeline`: reuses its options and
    // consumes its output channel as this pipeline's input.
    template <class PreTypeIn>
    int connect_to(Pipeline<PreTypeIn, TypeIn>& pre_pipeline,
        PipeDataConverter data_converter) {
        return initialize(pre_pipeline.options(), pre_pipeline.output_chnnel(), data_converter);
    }

    // Signals the conversion thread to stop and waits for it to exit.
    virtual ~Pipeline() {
        _is_read_end = true;
        if (_convert_thread != nullptr) {
            _convert_thread->join();
        }
    }

    // Clears `p`, reads converted items from the output channel into it and
    // returns the number of items read.
    inline size_t read(std::vector<TypeOut>& p) {
        p.clear();
        size_t num = _output_channel->Read(p);
        return num;
    }

    inline const PipelineOptions& options() {
        return _options;
    }

    // Historical (misspelled) accessor, kept because connect_to() and any
    // external callers use this name.
    inline ::paddle::framework::Channel<TypeOut> output_chnnel() {
        return _output_channel;
    }

    // Correctly spelled alias of output_chnnel().
    inline ::paddle::framework::Channel<TypeOut> output_channel() {
        return _output_channel;
    }

private:
    // Body of the conversion thread: repeatedly reads a chunk from the input
    // channel, converts it and moves the result into the output channel.
    void async_convert_data() {
        // Chunk size: a quarter of the buffer, but at least three batches,
        // rounded down to a whole number of batches.
        size_t convert_batch_size = _input_data_buffer.size() / 4;
        if (convert_batch_size < _options.batch_size * 3) {
            convert_batch_size = 3 * _options.batch_size;
        }
        convert_batch_size = (convert_batch_size / _options.batch_size) * _options.batch_size;
        while (!_is_read_end) {
            while (!_is_read_end && _output_channel->Size() < _input_data_buffer.size()) {
                size_t read_size = _input_channel->Read(convert_batch_size, &_input_data_buffer[0]);
                if (read_size == 0) {
                    // Nothing more to read from the input channel.
                    _is_read_end = true;
                    break;
                }
                CHECK(_converter(&_input_data_buffer[0], &_output_data_buffer[0], read_size) == 0)
                    << "Data Converter Do Failed";
                _output_channel->WriteMove(read_size, &_output_data_buffer[0]);
                if (_options.need_hold_input_data) {
                    _input_channel_backup->WriteMove(read_size, &_input_data_buffer[0]);
                }
            }
            // Output side is full (or input ended); back off before polling again.
            sleep(1);
        }
    }

private:
    bool _inited = false;                                          // initialization state
    std::atomic<bool> _is_read_end{false};                         // input fully read, or shutdown requested;
                                                                   // shared with the convert thread
    PipelineOptions _options;                                      // pipe parameters
    PipeDataConverter _converter;                                  // converter
    std::vector<TypeIn> _input_data_buffer;                        // input data buffer
    std::vector<TypeOut> _output_data_buffer;                      // output data buffer
    std::shared_ptr<std::thread> _convert_thread;                  // asynchronous convert thread
    ::paddle::framework::Channel<TypeIn> _input_channel;           // input stream
    ::paddle::framework::Channel<TypeIn> _input_channel_backup;    // backup of the original input stream
    ::paddle::framework::Channel<TypeOut> _output_channel;         // output stream
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
浏览文件 @
c1d64d7e
...
...
@@ -4,29 +4,89 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
//配置初始化
int
MPIRuntimeEnvironment
::
initialize
(
YAML
::
Node
config
)
{
return
0
;
RuntimeEnvironment
::
RuntimeEnvironment
()
{}
RuntimeEnvironment
::~
RuntimeEnvironment
()
{}
// A node is the master of `role` when its rank within that role is 0.
bool RuntimeEnvironment::is_master_node(EnvironmentRole role) {
    const auto rank = rank_id(role);
    return rank == 0;
}
std
::
string
format_timestamp
(
time_t
time
,
const
char
*
format
)
{
std
::
string
result
;
struct
tm
p
=
*
localtime
(
&
time
);
char
time_str_buffer
[
64
];
int
size
=
strftime
(
time_str_buffer
,
64
,
format
,
&
p
);
if
(
size
>
0
)
{
result
.
assign
(
time_str_buffer
,
size
);
}
//环境初始化,会在所有依赖模块initialize后调用
int
MPIRuntimeEnvironment
::
wireup
()
{
return
result
;
}
struct
MpiNodeInfo
{
int
rank_id
=
-
1
;
int
node_num
=
0
;
MPI_Comm
mpi_comm
;
};
class
MPIRuntimeEnvironment
:
public
RuntimeEnvironment
{
public:
MPIRuntimeEnvironment
()
{}
virtual
~
MPIRuntimeEnvironment
()
{}
virtual
int
initialize
(
YAML
::
Node
config
)
{
return
0
;
}
//当前环境rank_idx
uint32_t
MPIRuntimeEnvironment
::
rank_idx
()
{
virtual
int
wireup
()
{
int
hr
=
MPI_Init
(
NULL
,
NULL
);
if
(
MPI_SUCCESS
!=
hr
)
{
LOG
(
FATAL
)
<<
"MPI_init failed with error code"
<<
hr
;
return
-
1
;
}
_roles_node_info
.
resize
(
static_cast
<
int
>
(
EnvironmentRole
::
ALL
)
+
1
);
set_role
(
EnvironmentRole
::
ALL
);
return
0
;
}
void
MPIRuntimeEnvironment
::
barrier_all
()
{
return
;
virtual
uint32_t
rank_id
(
EnvironmentRole
role
)
{
return
mpi_node_info
(
role
).
rank_id
;
}
virtual
uint32_t
node_num
(
EnvironmentRole
role
)
{
return
mpi_node_info
(
role
).
node_num
;
}
void
MPIRuntimeEnvironment
::
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
{
if
(
type
==
EnvironmentLogType
::
MASTER_LOG
&&
!
is_master_node
())
{
return
;
virtual
int
set_role
(
EnvironmentRole
role
)
{
auto
&
node_info
=
mpi_node_info
(
role
);
if
(
node_info
.
rank_id
<
0
)
{
if
(
role
==
EnvironmentRole
::
ALL
)
{
node_info
.
mpi_comm
=
MPI_COMM_WORLD
;
}
else
{
MPI_Comm_split
(
MPI_COMM_WORLD
,
static_cast
<
int
>
(
role
),
mpi_node_info
(
EnvironmentRole
::
ALL
).
rank_id
,
&
(
node_info
.
mpi_comm
));
}
MPI_Comm_rank
(
node_info
.
mpi_comm
,
&
(
node_info
.
rank_id
));
MPI_Comm_size
(
node_info
.
mpi_comm
,
&
(
node_info
.
node_num
));
}
VLOG
(
2
)
<<
log_str
;
return
;
return
0
;
}
REGISTER_CLASS
(
RuntimeEnvironment
,
MPIRuntimeEnvironment
);
virtual
void
barrier
(
EnvironmentRole
role
)
{
MPI_Barrier
(
mpi_node_info
(
role
).
mpi_comm
);
}
// Broadcasts the content of `ar` from rank `root_id` to all nodes of `role`.
// The payload length is broadcast first so that every node can resize its
// archive before the payload bytes themselves are broadcast.
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
    auto& node_info = mpi_node_info(role);
    int len = static_cast<int>(ar.length());
    MPI_Bcast(&len, 1, MPI_INT, root_id, node_info.mpi_comm);
    ar.resize(len);
    ar.set_cursor(ar.buffer());
    // Use the same root as the length broadcast (the original passed an
    // undeclared identifier `root` here).
    MPI_Bcast(ar.buffer(), len, MPI_BYTE, root_id, node_info.mpi_comm);
}
protected:
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
);
inline
MpiNodeInfo
&
mpi_node_info
(
EnvironmentRole
role
)
{
return
_roles_node_info
[
static_cast
<
int
>
(
role
)];
}
private:
std
::
vector
<
MpiNodeInfo
>
_roles_node_info
;
};
REGISTER_CLASS
(
RuntimeEnvironment
,
MPIRuntimeEnvironment
);
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
浏览文件 @
c1d64d7e
...
...
@@ -6,6 +6,7 @@
*/
#pragma once
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
...
...
@@ -21,26 +22,37 @@ enum class EnvironmentLogLevel {
};
enum
class
EnvironmentLogType
{
MASTER_LOG
=
0
,
//仅master节点对外输出
ALL_LOG
=
1
//所有节点都会对外输出
MASTER_LOG
=
0
,
//仅master节点对外输出
ALL_LOG
=
1
//所有节点都会对外输出
};
//保持该枚举值的连续递增,且ALL在尾部
enum
class
EnvironmentRole
{
WORKER
=
0
,
//训练Worker
PSERVER
=
1
,
//参数服务器
ALL
=
2
//所有角色,请保持在枚举尾部
};
class
RuntimeEnvironment
{
public:
RuntimeEnvironment
()
{}
virtual
~
RuntimeEnvironment
()
{}
RuntimeEnvironment
()
;
virtual
~
RuntimeEnvironment
()
;
//配置初始化
virtual
int
initialize
(
YAML
::
Node
config
)
=
0
;
//设置role
virtual
int
set_role
(
EnvironmentRole
role
)
=
0
;
//环境初始化,会在所有依赖模块initialize后调用
virtual
int
wireup
()
=
0
;
//多线程可调用接口 Start
//当前环境rank_idx
virtual
uint32_t
rank_idx
()
=
0
;
virtual
uint32_t
rank_id
(
EnvironmentRole
role
)
=
0
;
//运行环境节点数
virtual
uint32_t
node_num
(
EnvironmentRole
role
)
=
0
;
//环境内主节点
virtual
bool
is_master_node
()
{
return
rank_idx
()
==
0
;
}
virtual
bool
is_master_node
(
EnvironmentRole
role
);
//环境定制化log
template
<
class
...
ARGS
>
void
log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
...
...
@@ -51,29 +63,22 @@ public:
//接口只允许在主线程调用 Start
//barrier
virtual
void
barrier_all
()
=
0
;
//barrier 指定role的节点
virtual
void
barrier
(
EnvironmentRole
role
)
=
0
;
//bcast 广播
virtual
void
bcast
(
paddle
::
framework
::
BinaryArchive
&
ar
,
int
root_id
,
EnvironmentRole
role
)
=
0
;
//接口只允许在主线程调用 End
protected:
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
=
0
;
};
REGISTER_REGISTERER
(
RuntimeEnvironment
);
class
MPIRuntimeEnvironment
:
public
RuntimeEnvironment
{
public:
MPIRuntimeEnvironment
()
{}
virtual
~
MPIRuntimeEnvironment
()
{}
//配置初始化
virtual
int
initialize
(
YAML
::
Node
config
);
//环境初始化,会在所有依赖模块initialize后调用
virtual
int
wireup
();
//当前环境rank_idx
virtual
uint32_t
rank_idx
();
virtual
void
barrier_all
();
protected:
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
);
};
std
::
string
format_timestamp
(
time_t
time
,
const
char
*
format
);
// Convenience overload taking the format as std::string; forwards to the
// const char* overload.
// Marked inline because the definition lives in a header: without it every
// translation unit including this header would emit its own definition.
inline std::string format_timestamp(time_t time, const std::string& format) {
    return format_timestamp(time, format.c_str());
}
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
浏览文件 @
c1d64d7e
...
...
@@ -8,6 +8,7 @@
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
namespace
paddle
{
...
...
@@ -36,6 +37,11 @@ public:
std
::
string
data
;
//样本数据, maybe压缩格式
};
typedef
std
::
shared_ptr
<
Pipeline
<
DataItem
,
SampleInstance
>>
SampleInstancePipe
;
// Creates a new DataItem -> SampleInstance pipeline that has not been initialized yet.
inline SampleInstancePipe make_sample_instance_channel() {
    auto pipe = std::make_shared<Pipeline<DataItem, SampleInstance>>();
    return pipe;
}
class
DataParser
{
public:
DataParser
()
{}
...
...
@@ -56,8 +62,12 @@ public:
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
//判断样本数据是否已就绪,就绪表明可以开始download
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
=
0
;
//读取数据样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
=
0
;
//读取dir下文件列表
virtual
std
::
vector
<
std
::
string
>
data_file_list
(
const
std
::
string
&
data_dir
);
//读取目录下数据到样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
::
paddle
::
framework
::
Channel
<
DataItem
>&
data_channel
)
=
0
;
//读取指定文件列表的数据到样本流中
virtual
int
read_all
(
const
std
::
vector
<
std
::
string
>&
data_list
,
::
paddle
::
framework
::
Channel
<
DataItem
>&
data_channel
)
=
0
;
virtual
const
DataParser
*
get_parser
()
{
return
_parser
.
get
();
}
...
...
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
0 → 100644
浏览文件 @
c1d64d7e
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Creates one DatasetContainer per entry of config["data_list"] and stores it
// under the entry's "name".
// Returns 0 on success, -1 when "data_list" is missing or a container fails
// to initialize.
int Dataset::initialize(
    const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    if (!config["data_list"]) {
        VLOG(0) << "miss data_list config in dataset, please check";
        return -1;
    }

    const int data_num = config["data_list"].size();
    for (int idx = 0; idx < data_num; ++idx) {
        // Kept const so that the lookups below stay read-only.
        const YAML::Node entry = config["data_list"][idx];
        std::string name = entry["name"].as<std::string>();

        auto container = std::make_shared<DatasetContainer>();
        const int ret = container->initialize(entry, context);
        if (ret != 0) {
            VLOG(0) << "dataset initialize failed, name:" << name;
            return -1;
        }
        _data_containers[name] = container;
    }
    return 0;
}
// Triggers detection of prefetchable data for `epoch_id` on the container
// registered as `data_name`. Does nothing when no such container exists.
//
// Not declared inline: this is the out-of-line definition of a member that
// other translation units call through the header declaration.
void Dataset::pre_detect_data(const std::string& data_name, uint64_t epoch_id) {
    // find() instead of operator[]: the latter would insert a null container
    // for an unknown name and then dereference it.
    auto it = _data_containers.find(data_name);
    if (it == _data_containers.end() || it->second == nullptr) {
        return;
    }
    it->second->pre_detect_data(epoch_id);
}
// Returns the data status of `epoch_id` for the container registered as
// `data_name`, or DatasetStatus::Empty when no such container exists.
//
// Not declared inline: this is the out-of-line definition of a member that
// other translation units call through the header declaration.
DatasetStatus Dataset::epoch_data_status(const std::string& data_name, uint64_t epoch_id) {
    // find() instead of operator[]: the latter would insert a null container
    // for an unknown name and then dereference it.
    auto it = _data_containers.find(data_name);
    if (it == _data_containers.end() || it->second == nullptr) {
        return DatasetStatus::Empty;
    }
    return it->second->epoch_data_status(epoch_id);
}
// Returns the raw data channel of `epoch_id` from the container registered as
// `data_name`. An empty (null) channel is returned when no such container
// exists.
//
// Not declared inline: this is the out-of-line definition of a member that
// other translation units call through the header declaration.
::paddle::framework::Channel<DataItem> Dataset::fetch_data(
    const std::string& data_name, uint64_t epoch_id) {
    // find() instead of operator[]: the latter would insert a null container
    // for an unknown name and then dereference it.
    auto it = _data_containers.find(data_name);
    if (it == _data_containers.end() || it->second == nullptr) {
        return ::paddle::framework::Channel<DataItem>();
    }
    return it->second->fetch(epoch_id);
}
// Returns a pipeline that yields parsed SampleInstance objects for `epoch_id`
// of the container registered as `data_name`. The raw DataItems are parsed
// inside the pipeline with the container's DataParser.
//
// Returns an empty (null) SampleInstancePipe when:
//   - no container is registered under `data_name`,
//   - `batch_size` is 0,
//   - the container has no data channel for `epoch_id`,
//   - the container has no parser, or
//   - the pipeline fails to initialize.
SampleInstancePipe Dataset::fetch_sample(
    const std::string& data_name, uint32_t batch_size, uint64_t epoch_id) {
    // find() instead of operator[]: the latter would insert a null container
    // for an unknown name and then dereference it.
    auto it = _data_containers.find(data_name);
    if (it == _data_containers.end() || it->second == nullptr) {
        return SampleInstancePipe();
    }
    if (batch_size == 0) {
        // Pipeline::initialize aborts on a zero batch size.
        return SampleInstancePipe();
    }

    auto* data_container = it->second.get();
    auto data_channel = data_container->fetch(epoch_id);
    if (data_channel == nullptr) {
        // The pipeline dereferences its input channel, so bail out here.
        return SampleInstancePipe();
    }
    const auto* data_parser = data_container->data_parser();
    if (data_parser == nullptr) {
        return SampleInstancePipe();
    }

    PipelineOptions options;
    options.batch_size = batch_size;
    options.need_hold_input_data = true;
    options.buffer_data_num = batch_size * 10;

    SampleInstancePipe pipe = make_sample_instance_channel();
    const int init_ret = pipe->initialize(options, data_channel,
        [data_parser](const DataItem* data, SampleInstance* sample, size_t num) -> int {
            // Accumulate failures: any non-zero parse result makes the batch fail.
            int ret = 0;
            for (size_t i = 0; i < num; ++i, ++data, ++sample) {
                ret |= data_parser->parse_to_sample(*data, *sample);
            }
            return ret;
        });
    if (init_ret != 0) {
        return SampleInstancePipe();
    }
    return pipe;
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/dataset.h
0 → 100644
浏览文件 @
c1d64d7e
#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Set of DatasetContainer objects addressed by name.
class Dataset {
public:
    Dataset() = default;
    virtual ~Dataset() = default;

    virtual int initialize(
        const YAML::Node& config, std::shared_ptr<TrainerContext> context);

    // Triggers detection of data that can be prefetched.
    virtual void pre_detect_data(const std::string& data_name, uint64_t epoch_id);

    // Returns the data status.
    virtual DatasetStatus epoch_data_status(const std::string& data_name, uint64_t epoch_id);

    // Returns the raw data held by the DataContainer (maybe in compressed format).
    virtual ::paddle::framework::Channel<DataItem> fetch_data(
        const std::string& data_name, uint64_t epoch_id);

    // Returns the standard sample stream as a pipeline; the data is converted
    // asynchronously inside the pipeline.
    virtual SampleInstancePipe fetch_sample(
        const std::string& data_name, uint32_t batch_size, uint64_t epoch_id);

private:
    std::unordered_map<std::string, std::shared_ptr<DatasetContainer>> _data_containers;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
浏览文件 @
c1d64d7e
...
...
@@ -8,31 +8,148 @@
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Reads the container settings from `config`, allocates the prefetch slots and
// creates and initializes the configured DataReader.
//
// Config keys read: "prefetch_num", "root_path" (space separated list),
// "data_spit_interval", "data_path_formater", "data_reader".
//
// Returns the DataReader's initialize() result, or -1 when "prefetch_num" or
// "data_spit_interval" is not positive or the reader class cannot be created.
// NOTE(review): `context` is stored as a raw pointer, so the caller presumably
// keeps the TrainerContext alive for the lifetime of this container - confirm.
int DatasetContainer::initialize(
    const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    _dataset_config = config;
    _trainer_context = context.get();

    // Number of epochs of sample data to prefetch. It is used as a modulus
    // when picking a slot, so it has to be positive.
    _prefetch_num = config["prefetch_num"].as<int>();
    if (_prefetch_num <= 0) {
        return -1;
    }
    _dataset_list.resize(_prefetch_num);
    // resize() only creates null pointers; every slot must hold a live object
    // because dataset() hands these pointers out and callers dereference them.
    // The element type is taken from the container so both stay in sync.
    using DatasetInfoType = decltype(_dataset_list)::value_type::element_type;
    for (auto& slot : _dataset_list) {
        if (slot == nullptr) {
            slot = std::make_shared<DatasetInfoType>();
        }
    }

    _data_root_paths = paddle::string::split_string(
        config["root_path"].as<std::string>(), " ");
    // Used as a divisor when aligning timestamps, so it has to be positive.
    _data_split_interval = config["data_spit_interval"].as<int>();
    if (_data_split_interval <= 0) {
        return -1;
    }
    _data_path_formater = config["data_path_formater"].as<std::string>();

    std::string data_reader_class = config["data_reader"].as<std::string>();
    DataReader* data_reader = CREATE_CLASS(DataReader, data_reader_class);
    if (data_reader == nullptr) {
        return -1;
    }
    _data_reader.reset(data_reader);
    return _data_reader->initialize(config, context);
}
// Returns the prefetch slot that `timestamp` maps to: the index of the epoch
// interval containing the timestamp, taken modulo the number of slots.
std::shared_ptr<DatasetInfo> DatasetContainer::dataset(uint64_t timestamp) {
    auto* accessor = _trainer_context->epoch_accessor.get();
    const auto interval_idx = timestamp / accessor->epoch_time_interval();
    const auto slot_idx = interval_idx % _prefetch_num;
    return _dataset_list[slot_idx];
}
// Looks up the data files belonging to `epoch_id` and, when every directory
// could be listed, records them in the epoch's prefetch slot and marks the
// slot as Detected.
void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
    auto* accessor = _trainer_context->epoch_accessor.get();
    time_t timestamp = accessor->epoch_timestamp(epoch_id);
    // The epoch timestamp has to sit exactly on an epoch interval boundary.
    if (timestamp % accessor->epoch_time_interval() != 0) {
        LOG(FATAL) << "timestamp:" << timestamp << " don't match interval:"
                   << accessor->epoch_time_interval();
        return;
    }

    size_t data_num = data_num_for_train(
        timestamp, accessor->epoch_time_interval(), _data_split_interval);

    // First data split at or after the epoch timestamp.
    uint64_t data_timestamp = timestamp;
    if (timestamp % _data_split_interval != 0) {
        data_timestamp = (timestamp / _data_split_interval + 1) * _data_split_interval;
    }

    // Visit every (root path, split) pair; stop at the first listing failure.
    int status = 0;
    std::vector<std::string> data_path_list;
    for (int root_idx = 0; status == 0 && root_idx < _data_root_paths.size(); ++root_idx) {
        for (int split_idx = 0; status == 0 && split_idx < data_num; ++split_idx) {
            std::string path_suffix = format_timestamp(
                data_timestamp + split_idx * _data_split_interval, _data_path_formater);
            std::string data_dir = _data_root_paths[root_idx] + "/" + path_suffix;
            status = read_data_list(data_dir, data_path_list);
        }
    }
    if (status != 0) {
        return;
    }

    auto dataset_info = dataset(timestamp);
    dataset_info->timestamp = timestamp;
    dataset_info->file_path_list = std::move(data_path_list);
    dataset_info->status = DatasetStatus::Detected;
}
// Appends to `data_list` the files under `data_dir` assigned to this worker.
//
// The WORKER master node checks readiness and lists the directory; both
// results are broadcast from rank 0 to the WORKER role. Each worker then takes
// every worker_num-th file starting at its own rank.
// Returns 0 on success, -1 when the broadcast readiness flag is not 0.
int DatasetContainer::read_data_list(
    const std::string& data_dir, std::vector<std::string>& data_list) {
    auto* env = _trainer_context->environment.get();

    // Check whether the data is ready.
    int ready_flag = -1;
    if (env->is_master_node(EnvironmentRole::WORKER) && _data_reader->is_data_ready(data_dir)) {
        ready_flag = 0;
    }
    paddle::framework::BinaryArchive ar;
    ar << ready_flag;
    env->bcast(ar, 0, EnvironmentRole::WORKER);
    ar >> ready_flag;
    if (ready_flag != 0) {
        return -1;
    }

    // Read the file list.
    ar.Clear();
    std::vector<std::string> all_files;
    if (env->is_master_node(EnvironmentRole::WORKER)) {
        all_files = _data_reader->data_file_list(data_dir);
        ar << all_files;
    }
    env->bcast(ar, 0, EnvironmentRole::WORKER);
    ar >> all_files;

    // Strided assignment of files to workers.
    auto worker_id = env->rank_id(EnvironmentRole::WORKER);
    auto worker_num = env->node_num(EnvironmentRole::WORKER);
    int file_idx = worker_id;
    while (file_idx < all_files.size()) {
        data_list.push_back(all_files[file_idx]);
        file_idx += worker_num;
    }

    env->barrier(EnvironmentRole::WORKER);
    return 0;
}
// Returns the data status of the epoch: the epoch id is translated to its
// data timestamp, which is then looked up with data_status().
DatasetStatus DatasetContainer::epoch_data_status(uint64_t epoch_id) {
    auto* accessor = _trainer_context->epoch_accessor.get();
    const time_t epoch_ts = accessor->epoch_timestamp(epoch_id);
    return data_status(epoch_ts);
}
// Returns the status stored in the slot for `timestamp`. When the slot
// currently describes a different timestamp the data counts as Empty.
DatasetStatus DatasetContainer::data_status(uint64_t timestamp) {
    const auto info = dataset(timestamp);
    return info->timestamp == timestamp ? info->status : DatasetStatus::Empty;
}
paddle
::
framework
::
Channel
<
DataItem
>
DatasetContainer
::
fetch
(
in
t
epoch_id
)
{
paddle
::
framework
::
Channel
<
DataItem
>
DatasetContainer
::
fetch
(
uint64_
t
epoch_id
)
{
paddle
::
framework
::
Channel
<
DataItem
>
result
;
if
(
_ready_epoch_id
<
epoch_id
)
{
auto
*
epoch_accessor
=
_trainer_context
->
epoch_accessor
.
get
();
time_t
timestamp
=
epoch_accessor
->
epoch_timestamp
(
epoch_id
);
if
(
data_status
(
timestamp
)
!=
DatasetStatus
::
Ready
)
{
return
result
;
}
_current_epoch_id
=
epoch_id
;
_current_dataset_idx
=
epoch_id
%
_prefetch_num
;
//result = _dataset_list[_current_dataset_idx].fetch();
//_dataset_list[_current_dataset_idx].reset((decltype(result.get())*)NULL);
return
result
;
auto
dataset_info
=
dataset
(
timestamp
);
return
dataset_info
->
data_channel
;
}
void
DatasetContainer
::
async_download_data
()
{
void
DatasetContainer
::
async_download_data
(
uint64_t
start_timestamp
)
{
auto
*
epoch_accessor
=
_trainer_context
->
epoch_accessor
.
get
();
if
(
start_timestamp
%
epoch_accessor
->
epoch_time_interval
()
!=
0
)
{
LOG
(
FATAL
)
<<
"timestamp:"
<<
start_timestamp
<<
" don't match interval:"
<<
epoch_accessor
->
epoch_time_interval
();
return
;
}
while
(
true
)
{
//do download
sleep
(
30
);
auto
dataset_info
=
dataset
(
start_timestamp
);
while
(
data_status
(
start_timestamp
)
!=
DatasetStatus
::
Detected
)
{
sleep
(
30
);
}
const
auto
&
file_list
=
dataset_info
->
file_path_list
;
dataset_info
->
data_channel
->
Clear
();
while
(
_data_reader
->
read_all
(
file_list
,
dataset_info
->
data_channel
)
!=
0
)
{
dataset_info
->
data_channel
->
Clear
();
VLOG
(
0
)
<<
"timestamp:"
<<
start_timestamp
<<
" data read failed, retry"
;
sleep
(
30
);
}
start_timestamp
+=
epoch_accessor
->
epoch_time_interval
();
}
}
}
//
namespace feed
}
//
namespace custom_trainer
}
//
namespace paddle
}
//
namespace feed
}
//
namespace custom_trainer
}
//
namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
浏览文件 @
c1d64d7e
...
...
@@ -16,38 +16,61 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
// Counts the data splits whose start time (a multiple of `data_time_interval`)
// falls inside [train_begin_timestamp, train_begin_timestamp + train_time_interval).
//
// train_begin_timestamp: start of the training window.
// train_time_interval:   length of the training window.
// data_time_interval:    spacing between consecutive data splits.
// Returns the number of splits, or 0 when `data_time_interval` is 0 (the
// count is undefined and the divisions below would otherwise divide by zero).
inline int data_num_for_train(uint64_t train_begin_timestamp,
    uint32_t train_time_interval, uint32_t data_time_interval) {
    if (data_time_interval == 0) {
        return 0;
    }
    const uint64_t data_begin_time = train_begin_timestamp;
    const uint64_t data_end_time = data_begin_time + train_time_interval;
    // Both indices are ceil(time / data_time_interval).
    const uint64_t begin_idx = (data_begin_time + data_time_interval - 1) / data_time_interval;
    const uint64_t end_idx = (data_end_time + data_time_interval - 1) / data_time_interval;
    return static_cast<int>(end_idx - begin_idx);
}
// State of one prefetch slot. Values increase from Empty (0) to Ready (3).
// NOTE(review): "Downloding" is misspelled but is part of the public enum;
// renaming it would break any code that refers to it.
enum class DatasetStatus {
    Empty = 0,
    Detected,
    Downloding,
    Ready
};
// Bookkeeping for one prefetch slot.
struct DatasetInfo {
    // Timestamp the slot currently describes; 0 until a timestamp is assigned.
    uint64_t timestamp{0};
    // Files that make up the data of this timestamp.
    std::vector<std::string> file_path_list;
    // Current state of the slot.
    DatasetStatus status{DatasetStatus::Empty};
    // Channel holding the data items of this slot; created with the struct.
    ::paddle::framework::Channel<DataItem> data_channel =
        ::paddle::framework::MakeChannel<DataItem>();
};
class
DatasetContainer
{
public:
DatasetContainer
()
{}
virtual
~
DatasetContainer
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
)
{
_dataset_config
=
config
;
//预取n轮样本数据
_prefetch_num
=
config
[
"prefetch_num"
].
as
<
int
>
();
_data_root_path
=
config
[
"root_path"
].
as
<
std
::
string
>
();
_data_path_generater
=
config
[
"_data_path_generater"
].
as
<
std
::
string
>
();
return
0
;
}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
);
virtual
void
run
();
//获取特定epoch_i样本,如果数据未ready,Channel内为空指针
virtual
::
paddle
::
framework
::
Channel
<
DataItem
>
fetch
(
int
epoch_id
);
//触发可预取的数据判断
virtual
void
pre_detect_data
(
RuntimeEnvironment
*
env
);
virtual
void
pre_detect_data
(
uint64_t
epoch_id
);
//获取数据状态
virtual
DatasetStatus
epoch_data_status
(
uint64_t
epoch_id
);
//获取特定epoch_i样本,如果数据未ready,Channel内为空指针
virtual
::
paddle
::
framework
::
Channel
<
DataItem
>
fetch
(
uint64_t
epoch_id
);
//获取DataItem解析器
virtual
const
DataParser
*
data_parser
()
{
return
_data_reader
->
get_parser
();
}
protected:
virtual
DatasetStatus
data_status
(
uint64_t
timestamp
);
virtual
int
read_data_list
(
const
std
::
string
&
data_dir
,
std
::
vector
<
std
::
string
>&
data_list
);
//异步样本download
virtual
void
async_download_data
();
virtual
void
download
(
int
epoch_id
,
const
std
::
vector
<
std
::
string
>&
paths
);
virtual
void
async_download_data
(
uint64_t
start_timestamp
);
virtual
std
::
shared_ptr
<
DatasetInfo
>
dataset
(
uint64_t
timestamp
);
int
_prefetch_num
=
0
;
int
_prefetch_num
=
0
;
int
_data_split_interval
=
60
;
//样本切分周期(秒)
YAML
::
Node
_dataset_config
;
std
::
string
_data_
root_path
;
std
::
string
_data_path_generater
;
std
::
string
_data_
path_formater
;
std
::
vector
<
std
::
string
>
_data_root_paths
;
//支持同时读取多个目录
uint32_t
_current_dataset_idx
;
//当前样本数据idx
int
_current_epoch_id
=
-
1
;
int
_ready_epoch_id
=
-
1
;
//已下载完成的epoch_id
std
::
vector
<
std
::
shared_ptr
<
::
paddle
::
framework
::
Dataset
>>
_dataset_list
;
//预取的数据列表
TrainerContext
*
_trainer_context
;
std
::
shared_ptr
<
DataReader
>
_data_reader
;
std
::
shared_ptr
<
std
::
thread
>
_downloader_thread
;
std
::
vector
<
std
::
shared_ptr
<
DatasetInfo
>>
_dataset_list
;
//预取的数据列表
};
}
//namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
c1d64d7e
...
...
@@ -35,7 +35,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return
0
;
}
std
::
future
<
int
>
LearnerProcess
::
save_model
(
in
t
epoch_id
,
int
table_id
,
ModelSaveWay
way
)
{
std
::
future
<
int
>
LearnerProcess
::
save_model
(
uint64_
t
epoch_id
,
int
table_id
,
ModelSaveWay
way
)
{
std
::
promise
<
int
>
p
;
auto
ret
=
p
.
get_future
();
if
(
_context_ptr
->
epoch_accessor
->
need_save_model
(
epoch_id
,
way
))
{
...
...
@@ -47,7 +47,7 @@ std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSav
return
ret
;
}
int
LearnerProcess
::
wait_save_model
(
in
t
epoch_id
,
ModelSaveWay
way
)
{
int
LearnerProcess
::
wait_save_model
(
uint64_
t
epoch_id
,
ModelSaveWay
way
)
{
auto
*
environment
=
_context_ptr
->
environment
.
get
();
if
(
!
environment
->
is_master_node
())
{
return
0
;
...
...
@@ -71,7 +71,7 @@ int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
int
LearnerProcess
::
run
()
{
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
in
t
epoch_id
=
epoch_accessor
->
current_epoch_id
();
uint64_
t
epoch_id
=
epoch_accessor
->
current_epoch_id
();
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Resume traine with epoch_id:%d label:%s"
,
epoch_id
,
_context_ptr
->
epoch_accessor
->
text
(
epoch_id
).
c_str
());
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
浏览文件 @
c1d64d7e
...
...
@@ -21,9 +21,9 @@ public:
protected:
//同步保存所有模型
virtual
int
wait_save_model
(
in
t
epoch_id
,
ModelSaveWay
way
);
virtual
int
wait_save_model
(
uint64_
t
epoch_id
,
ModelSaveWay
way
);
//异步保存指定模型
virtual
std
::
future
<
int
>
save_model
(
in
t
epoch_id
,
int
table_id
,
ModelSaveWay
way
);
virtual
std
::
future
<
int
>
save_model
(
uint64_
t
epoch_id
,
int
table_id
,
ModelSaveWay
way
);
//执行指定训练网络
virtual
int
run_executor
(
Executor
*
executor
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录