Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
d7ee6ba1
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d7ee6ba1
编写于
8月 13, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
差异文件
commit local runtimeEnvironment
上级
c1d64d7e
1a032f40
变更
40
展开全部
隐藏空白更改
内联
并排
Showing
40 changed file
with
2343 addition
and
15076 deletion
+2343
-15076
BCLOUD
BCLOUD
+16
-25
BCLOUD.paddle
BCLOUD.paddle
+78
-0
paddle/fluid/framework/channel.h
paddle/fluid/framework/channel.h
+1
-1
paddle/fluid/framework/io/fs.cc
paddle/fluid/framework/io/fs.cc
+20
-1
paddle/fluid/framework/io/fs.h
paddle/fluid/framework/io/fs.h
+4
-0
paddle/fluid/string/string_helper.h
paddle/fluid/string/string_helper.h
+12
-0
paddle/fluid/train/custom_trainer/feed/.clang-format
paddle/fluid/train/custom_trainer/feed/.clang-format
+33
-0
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+39
-8
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
...fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+37
-10
paddle/fluid/train/custom_trainer/feed/accessor/me
paddle/fluid/train/custom_trainer/feed/accessor/me
+0
-14840
paddle/fluid/train/custom_trainer/feed/common/pipeline.h
paddle/fluid/train/custom_trainer/feed/common/pipeline.h
+1
-0
paddle/fluid/train/custom_trainer/feed/common/registerer.h
paddle/fluid/train/custom_trainer/feed/common/registerer.h
+1
-1
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
...d/train/custom_trainer/feed/common/runtime_environment.cc
+50
-8
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
...id/train/custom_trainer/feed/common/runtime_environment.h
+6
-7
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
+30
-2
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
...le/fluid/train/custom_trainer/feed/dataset/data_reader.cc
+210
-0
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+7
-11
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
+20
-6
paddle/fluid/train/custom_trainer/feed/dataset/dataset.h
paddle/fluid/train/custom_trainer/feed/dataset/dataset.h
+2
-0
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
...id/train/custom_trainer/feed/dataset/dataset_container.cc
+29
-15
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
...uid/train/custom_trainer/feed/dataset/dataset_container.h
+8
-2
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
+67
-69
paddle/fluid/train/custom_trainer/feed/executor/executor.h
paddle/fluid/train/custom_trainer/feed/executor/executor.h
+0
-12
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc
...le/fluid/train/custom_trainer/feed/io/auto_file_system.cc
+129
-0
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
+42
-0
paddle/fluid/train/custom_trainer/feed/io/file_system.h
paddle/fluid/train/custom_trainer/feed/io/file_system.h
+59
-0
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
.../fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
+210
-0
paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc
...e/fluid/train/custom_trainer/feed/io/local_file_system.cc
+136
-0
paddle/fluid/train/custom_trainer/feed/io/shell.cc
paddle/fluid/train/custom_trainer/feed/io/shell.cc
+380
-0
paddle/fluid/train/custom_trainer/feed/io/shell.h
paddle/fluid/train/custom_trainer/feed/io/shell.h
+78
-0
paddle/fluid/train/custom_trainer/feed/model/epoch_donefile.txt
.../fluid/train/custom_trainer/feed/model/epoch_donefile.txt
+3
-0
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
...uid/train/custom_trainer/feed/process/init_env_process.cc
+27
-6
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+14
-11
paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so
...train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so
+0
-0
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+9
-4
paddle/fluid/train/custom_trainer/feed/unit_test/main.cc
paddle/fluid/train/custom_trainer/feed/unit_test/main.cc
+1
-0
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc
...id/train/custom_trainer/feed/unit_test/test_datareader.cc
+278
-0
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc
...rain/custom_trainer/feed/unit_test/test_datareader_omp.cc
+211
-0
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
...luid/train/custom_trainer/feed/unit_test/test_executor.cc
+89
-37
release.bcloud
release.bcloud
+6
-0
未找到文件。
BCLOUD
浏览文件 @
d7ee6ba1
此差异已折叠。
点击以展开。
BCLOUD.paddle
0 → 100644
浏览文件 @
d7ee6ba1
此差异已折叠。
点击以展开。
paddle/fluid/framework/channel.h
浏览文件 @
d7ee6ba1
...
...
@@ -332,7 +332,7 @@ class ChannelReader {
}
if
(
cursor_
>=
buffer_
.
size
())
{
cursor_
=
0
;
if
(
channel_
->
r
ead
(
buffer_
)
==
0
)
{
if
(
channel_
->
R
ead
(
buffer_
)
==
0
)
{
failed_
=
true
;
return
*
this
;
}
...
...
paddle/fluid/framework/io/fs.cc
浏览文件 @
d7ee6ba1
...
...
@@ -149,7 +149,7 @@ std::vector<std::string> localfs_list(const std::string& path) {
std
::
shared_ptr
<
FILE
>
pipe
;
int
err_no
=
0
;
pipe
=
shell_popen
(
string
::
format_string
(
"find %s -
type f -maxdepth 1
"
,
path
.
c_str
()),
"r"
,
string
::
format_string
(
"find %s -
maxdepth 1 -type f
"
,
path
.
c_str
()),
"r"
,
&
err_no
);
string
::
LineFileReader
reader
;
std
::
vector
<
std
::
string
>
list
;
...
...
@@ -452,5 +452,24 @@ void fs_mkdir(const std::string& path) {
LOG
(
FATAL
)
<<
"Not supported"
;
}
}
std
::
string
fs_path_join
(
const
std
::
string
&
dir
,
const
std
::
string
&
path
)
{
if
(
dir
.
empty
())
{
return
path
;
}
if
(
dir
.
back
()
==
'/'
)
{
return
dir
+
path
;
}
return
dir
+
'/'
+
path
;
}
std
::
pair
<
std
::
string
,
std
::
string
>
fs_path_split
(
const
std
::
string
&
path
)
{
size_t
pos
=
path
.
find_last_of
(
'/'
);
if
(
pos
==
std
::
string
::
npos
)
{
return
{
"."
,
path
};
}
return
{
path
.
substr
(
0
,
pos
),
path
.
substr
(
pos
+
1
)};
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/io/fs.h
浏览文件 @
d7ee6ba1
...
...
@@ -97,5 +97,9 @@ extern std::string fs_tail(const std::string& path);
extern
bool
fs_exists
(
const
std
::
string
&
path
);
extern
void
fs_mkdir
(
const
std
::
string
&
path
);
extern
std
::
string
fs_path_join
(
const
std
::
string
&
dir
,
const
std
::
string
&
path
);
extern
std
::
pair
<
std
::
string
,
std
::
string
>
fs_path_split
(
const
std
::
string
&
path
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/string/string_helper.h
浏览文件 @
d7ee6ba1
...
...
@@ -136,6 +136,18 @@ std::string join_strings(const Container& strs, char delim) {
return
str
;
}
static
inline
bool
end_with
(
const
std
::
string
&
main_str
,
const
std
::
string
&
str
)
{
return
main_str
.
length
()
>=
str
.
length
()
&&
strncmp
(
main_str
.
c_str
()
+
main_str
.
length
()
-
str
.
length
(),
str
.
c_str
(),
str
.
length
())
==
0
;
}
static
inline
bool
begin_with
(
const
std
::
string
&
main_str
,
const
std
::
string
&
str
)
{
return
main_str
.
length
()
>=
str
.
length
()
&&
strncmp
(
main_str
.
c_str
(),
str
.
c_str
(),
str
.
length
())
==
0
;
}
// A helper class for reading lines from file. A line buffer is maintained. It
// doesn't need to know the maximum possible length of a line.
...
...
paddle/fluid/train/custom_trainer/feed/.clang-format
0 → 100644
浏览文件 @
d7ee6ba1
BasedOnStyle: Google
AccessModifierOffset: -4
AlignAfterOpenBracket: AlwaysBreak
AlignOperands: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BreakConstructorInitializers: AfterColon
ColumnLimit: 100
ConstructorInitializerIndentWidth: 8
ContinuationIndentWidth: 8
DerivePointerAlignment: true
FixNamespaceComments: true
IndentCaseLabels: false
IndentWidth: 4
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 500
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 400
PointerAlignment: Left
SortIncludes: false
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
浏览文件 @
d7ee6ba1
#pragma once
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
int
EpochAccessor
::
initialize
(
YAML
::
Node
config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
_model_root_path
=
config
[
"model_root_path"
].
as
<
std
::
string
>
()
+
"/"
;
_done_file_path
=
_model_root_path
;
if
(
config
[
"donefile"
])
{
_done_file_path
.
append
(
config
[
"donefile"
].
as
<
std
::
string
>
());
}
else
{
_done_file_path
.
append
(
"epoch_donefile.txt"
);
}
if
(
!
context_ptr
->
file_system
->
exists
(
_done_file_path
))
{
VLOG
(
0
)
<<
"missing done file, path:"
<<
_done_file_path
;
}
std
::
string
done_text
=
context_ptr
->
file_system
->
tail
(
_done_file_path
);
_done_status
=
paddle
::
string
::
split_string
(
done_text
,
std
::
string
(
"
\t
"
));
_current_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
EpochIdField
);
_last_checkpoint_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
CheckpointIdField
);
_last_checkpoint_path
=
get_status
<
std
::
string
>
(
EpochStatusFiled
::
CheckpointPathField
);
return
0
;
}
int
HourlyEpochAccessor
::
initialize
(
YAML
::
Node
config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
EpochAccessor
::
initialize
(
config
,
context_ptr
);
return
0
;
}
void
HourlyEpochAccessor
::
next_epoch
()
{
_current_epoch_id
=
next_epoch_id
(
_current_epoch_id
);
}
std
::
string
HourlyEpochAccessor
::
text
(
uint64_t
epoch_id
)
{
return
std
::
to_string
(
epoch_id
);
return
format_timestamp
(
epoch_id
,
"%Y%m%d delta-%H"
);
}
bool
HourlyEpochAccessor
::
data_ready
(
uint64_t
epoch_id
)
{
return
true
;
}
int
HourlyEpochAccessor
::
next_epoch_id
(
uint64_t
epoch_id
)
{
uint64_t
HourlyEpochAccessor
::
next_epoch_id
(
uint64_t
epoch_id
)
{
if
(
epoch_id
==
0
)
{
struct
timeval
now
;
gettimeofday
(
&
now
,
NULL
);
...
...
@@ -25,15 +50,19 @@ namespace feed {
}
return
epoch_id
+
3600
;
}
bool
HourlyEpochAccessor
::
is_last_epoch
(
uint64_t
epoch_id
)
{
return
((
epoch_id
/
3600
)
%
24
)
==
23
;
}
}
uint64_t
HourlyEpochAccessor
::
epoch_time_interval
()
{
return
3600
;
}
uint64_t
HourlyEpochAccessor
::
epoch_timestamp
(
uint64_t
epoch_id
)
{
return
epoch_id
;
}
}
bool
HourlyEpochAccessor
::
need_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
{
if
(
epoch_id
==
0
)
{
return
false
;
...
...
@@ -47,6 +76,7 @@ namespace feed {
}
return
false
;
}
std
::
string
HourlyEpochAccessor
::
model_save_path
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
{
if
(
save_way
==
ModelSaveWay
::
ModelSaveInferenceDelta
)
{
return
_model_root_path
+
"/xbox/delta-"
+
std
::
to_string
(
epoch_id
);
...
...
@@ -57,6 +87,7 @@ namespace feed {
}
return
""
;
}
REGISTER_CLASS
(
EpochAccessor
,
HourlyEpochAccessor
);
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
浏览文件 @
d7ee6ba1
#pragma once
#include <boost/lexical_cast.hpp>
#include "paddle/fluid/string/to_string.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
...
...
@@ -6,20 +9,41 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
enum
class
EpochStatusFiled
{
DateField
=
0
,
TimestampField
=
1
,
CheckpointPathField
=
2
,
EpochIdField
=
3
,
CheckpointIdField
=
4
};
class
EpochAccessor
:
public
Accessor
{
public:
EpochAccessor
()
{}
virtual
~
EpochAccessor
()
{}
virtual
int
initialize
(
YAML
::
Node
config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
=
0
;
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
uint64_t
current_epoch_id
()
{
return
_current_epoch_id
;
}
virtual
void
next_epoch
()
=
0
;
virtual
std
::
string
text
(
uint64_t
epoch_id
)
=
0
;
virtual
bool
data_ready
(
uint64_t
epoch_id
)
=
0
;
virtual
int
next_epoch_id
(
uint64_t
epoch_id
)
=
0
;
virtual
const
std
::
string
&
checkpoint_path
()
{
return
_last_checkpoint_path
;
}
template
<
class
T
>
T
get_status
(
EpochStatusFiled
field
)
{
auto
status
=
paddle
::
string
::
trim_spaces
(
_done_status
[
static_cast
<
int
>
(
field
)]);
return
boost
::
lexical_cast
<
T
>
(
status
.
c_str
());
}
virtual
void
next_epoch
()
=
0
;
virtual
std
::
string
model_root_path
()
{
return
_model_root_path
;
}
virtual
std
::
string
text
(
uint64_t
epoch_id
)
=
0
;
virtual
uint64_t
next_epoch_id
(
uint64_t
epoch_id
)
=
0
;
virtual
bool
is_last_epoch
(
uint64_t
epoch_id
)
=
0
;
//epoch间的数据时间间隔(秒)
virtual
uint64_t
epoch_time_interval
()
=
0
;
...
...
@@ -28,7 +52,13 @@ public:
virtual
bool
need_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
=
0
;
virtual
std
::
string
model_save_path
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
=
0
;
protected:
uint64_t
_current_epoch_id
;
std
::
string
_done_file_path
;
std
::
string
_model_root_path
;
uint64_t
_current_epoch_id
=
0
;
std
::
string
_last_checkpoint_path
;
uint64_t
_last_checkpoint_epoch_id
=
0
;
std
::
vector
<
std
::
string
>
_done_status
;
//当前完成状态,统一存成string
};
REGISTER_REGISTERER
(
EpochAccessor
);
...
...
@@ -40,15 +70,12 @@ public:
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
void
next_epoch
();
virtual
std
::
string
text
(
uint64_t
epoch_id
);
virtual
bool
data_ready
(
uint64_t
epoch_id
);
virtual
int
next_epoch_id
(
uint64_t
epoch_id
);
virtual
uint64_t
next_epoch_id
(
uint64_t
epoch_id
);
virtual
bool
is_last_epoch
(
uint64_t
epoch_id
);
virtual
uint64_t
epoch_time_interval
();
virtual
uint64_t
epoch_timestamp
(
uint64_t
epoch_id
);
virtual
bool
need_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
);
virtual
std
::
string
model_save_path
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
);
private:
std
::
string
_model_root_path
;
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/accessor/me
已删除
100644 → 0
浏览文件 @
c1d64d7e
此差异已折叠。
点击以展开。
paddle/fluid/train/custom_trainer/feed/common/pipeline.h
浏览文件 @
d7ee6ba1
#pragma once
#include <thread>
#include "paddle/fluid/framework/archive.h"
namespace
paddle
{
...
...
paddle/fluid/train/custom_trainer/feed/common/registerer.h
浏览文件 @
d7ee6ba1
...
...
@@ -106,7 +106,7 @@ BaseClassMap& global_factory_map_cpp();
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name)
;
base_class##Registerer::CreateInstanceByName(name)
}
//namespace feed
}
//namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
浏览文件 @
d7ee6ba1
#include <mpi.h>
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace
paddle
{
...
...
@@ -68,26 +69,67 @@ public:
virtual
void
barrier
(
EnvironmentRole
role
)
{
MPI_Barrier
(
mpi_node_info
(
role
).
mpi_comm
);
}
virtual
void
bcast
(
paddle
::
framework
::
BinaryArchive
&
ar
,
int
root_id
,
EnvironmentRole
role
)
{
auto
&
node_info
=
mpi_node_info
(
role
);
int
len
=
(
int
)
ar
.
l
ength
();
int
len
=
(
int
)
ar
.
L
ength
();
MPI_Bcast
(
&
len
,
1
,
MPI_INT
,
root_id
,
node_info
.
mpi_comm
);
ar
.
r
esize
(
len
);
ar
.
set_cursor
(
ar
.
b
uffer
());
MPI_Bcast
(
ar
.
buffer
(),
len
,
MPI_BYTE
,
root
,
node_info
.
mpi_comm
);
ar
.
R
esize
(
len
);
ar
.
SetCursor
(
ar
.
B
uffer
());
MPI_Bcast
(
ar
.
Buffer
(),
len
,
MPI_BYTE
,
root_id
,
node_info
.
mpi_comm
);
}
protected:
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
);
virtual
void
print_log
(
EnvironmentRole
role
,
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
{
if
(
type
==
EnvironmentLogType
::
MASTER_LOG
&&
!
is_master_node
(
role
))
{
return
;
}
VLOG
(
static_cast
<
int
>
(
level
))
<<
log_str
;
}
inline
MpiNodeInfo
&
mpi_node_info
(
EnvironmentRole
role
)
{
return
_roles_node_info
[
static_cast
<
int
>
(
role
)];
}
private:
std
::
vector
<
MpiNodeInfo
>
_roles_node_info
;
};
REGISTER_CLASS
(
RuntimeEnvironment
,
MPIRuntimeEnvironment
);
//用于本地模式单机训练
class
LocalRuntimeEnvironment
:
public
RuntimeEnvironment
{
public:
LocalRuntimeEnvironment
()
{}
virtual
~
LocalRuntimeEnvironment
()
{}
virtual
int
initialize
(
YAML
::
Node
config
)
{
return
0
;
}
virtual
int
wireup
()
{
return
0
;
}
virtual
uint32_t
rank_id
(
EnvironmentRole
role
)
{
return
0
;
}
virtual
uint32_t
node_num
(
EnvironmentRole
role
)
{
return
1
;
}
virtual
int
set_role
(
EnvironmentRole
role
)
{
return
0
;
}
virtual
void
barrier
(
EnvironmentRole
role
)
{
return
;
}
virtual
void
bcast
(
paddle
::
framework
::
BinaryArchive
&
ar
,
int
root_id
,
EnvironmentRole
role
)
{
return
;
}
protected:
virtual
void
print_log
(
EnvironmentRole
role
,
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
{
VLOG
(
static_cast
<
int
>
(
level
))
<<
log_str
;
}
};
REGISTER_CLASS
(
RuntimeEnvironment
,
LocalRuntimeEnvironment
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
浏览文件 @
d7ee6ba1
...
...
@@ -55,9 +55,9 @@ public:
//环境定制化log
template
<
class
...
ARGS
>
void
log
(
Environment
LogType
type
,
EnvironmentLogLevel
level
,
const
char
*
fmt
,
ARGS
&&
...
args
)
{
print_log
(
type
,
level
,
paddle
::
string
::
format_string
(
fmt
,
args
...));
void
log
(
Environment
Role
role
,
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
char
*
fmt
,
ARGS
&&
...
args
)
{
print_log
(
role
,
type
,
level
,
paddle
::
string
::
format_string
(
fmt
,
args
...));
}
//多线程可调用接口 End
...
...
@@ -69,14 +69,13 @@ public:
virtual
void
bcast
(
paddle
::
framework
::
BinaryArchive
&
ar
,
int
root_id
,
EnvironmentRole
role
)
=
0
;
//接口只允许在主线程调用 End
protected:
virtual
void
print_log
(
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
=
0
;
virtual
void
print_log
(
EnvironmentRole
role
,
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
=
0
;
};
REGISTER_REGISTERER
(
RuntimeEnvironment
);
std
::
string
format_timestamp
(
time_t
time
,
const
char
*
format
);
std
::
string
format_timestamp
(
time_t
time
,
const
std
::
string
&
format
)
{
inline
std
::
string
format_timestamp
(
time_t
time
,
const
std
::
string
&
format
)
{
return
format_timestamp
(
time
,
format
.
c_str
());
}
...
...
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
浏览文件 @
d7ee6ba1
train_thread_num
:
10
environment
:
environment_class
:
MPIRuntimeEnvironment
environment_class
:
LocalRuntimeEnvironment
io
:
file_systems
:
afs
:
class
:
HadoopFileSystem
buffer_size
:
1024000
ugis
:
'
default'
:
'
feed_video,D3a0z8'
'
xingtian.afs.baidu.com:9902'
:
'
feed_video,D3a0z8'
local
:
class
:
LocalFileSystem
buffer_size
:
1024000
dataset
:
data_list
:
train_sample
:
prefetch_num
:
2
root_path
:
./sample
data_spit_interval
:
300
data_path_formater
:
'
%Y%m%d/%H%M'
data_reader
:
LineDataReader
done_file
:
to.hadoop.done
filename_prefix
:
part
pipeline_cmd
:
cat
parser
:
class
:
LineDataParser
epoch
:
epoch_class
:
HourlyEpochAccessor
epoch_class
:
HourlyEpochAccessor
model_root_path
:
./model/
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
0 → 100644
浏览文件 @
d7ee6ba1
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include <cstdio>
#include <atomic>
#include <glog/logging.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
class
LineDataParser
:
public
DataParser
{
public:
LineDataParser
()
{}
virtual
~
LineDataParser
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
return
0
;
}
virtual
int
parse
(
const
char
*
str
,
size_t
len
,
DataItem
&
data
)
const
{
size_t
pos
=
0
;
while
(
pos
<
len
&&
str
[
pos
]
!=
' '
)
{
++
pos
;
}
if
(
pos
>=
len
)
{
VLOG
(
2
)
<<
"fail to parse line: "
<<
std
::
string
(
str
,
len
)
<<
", strlen: "
<<
len
;
return
-
1
;
}
VLOG
(
5
)
<<
"getline: "
<<
str
<<
" , pos: "
<<
pos
<<
", len: "
<<
len
;
data
.
id
.
assign
(
str
,
pos
);
data
.
data
.
assign
(
str
+
pos
+
1
,
len
-
pos
-
1
);
return
0
;
}
virtual
int
parse
(
const
char
*
str
,
DataItem
&
data
)
const
{
size_t
pos
=
0
;
while
(
str
[
pos
]
!=
'\0'
&&
str
[
pos
]
!=
' '
)
{
++
pos
;
}
if
(
str
[
pos
]
==
'\0'
)
{
VLOG
(
2
)
<<
"fail to parse line: "
<<
str
<<
", get '
\\
0' at pos: "
<<
pos
;
return
-
1
;
}
VLOG
(
5
)
<<
"getline: "
<<
str
<<
" , pos: "
<<
pos
;
data
.
id
.
assign
(
str
,
pos
);
data
.
data
.
assign
(
str
+
pos
+
1
);
return
0
;
}
virtual
int
parse_to_sample
(
const
DataItem
&
data
,
SampleInstance
&
instance
)
const
{
return
0
;
}
};
REGISTER_CLASS
(
DataParser
,
LineDataParser
);
int
DataReader
::
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
_parser
.
reset
(
CREATE_CLASS
(
DataParser
,
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
()));
if
(
_parser
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to get parser: "
<<
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
();
return
-
1
;
}
if
(
_parser
->
initialize
(
config
[
"parser"
],
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to initialize parser"
<<
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
();
return
-
1
;
}
_pipeline_cmd
=
config
[
"pipeline_cmd"
].
as
<
std
::
string
>
();
return
0
;
}
class
LineDataReader
:
public
DataReader
{
public:
LineDataReader
()
{}
virtual
~
LineDataReader
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
if
(
DataReader
::
initialize
(
config
,
context
)
!=
0
)
{
return
-
1
;
}
_done_file_name
=
config
[
"done_file"
].
as
<
std
::
string
>
();
_filename_prefix
=
config
[
"filename_prefix"
].
as
<
std
::
string
>
(
""
);
if
(
config
[
"file_system"
]
&&
config
[
"file_system"
][
"class"
])
{
_file_system
.
reset
(
CREATE_CLASS
(
FileSystem
,
config
[
"file_system"
][
"class"
].
as
<
std
::
string
>
()));
if
(
_file_system
==
nullptr
||
_file_system
->
initialize
(
config
[
"file_system"
],
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to create class: "
<<
config
[
"file_system"
][
"class"
].
as
<
std
::
string
>
();
return
-
1
;
}
}
else
if
(
context
->
file_system
!=
nullptr
)
{
_file_system
=
context
->
file_system
;
}
else
{
_file_system
.
reset
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
if
(
_file_system
==
nullptr
||
_file_system
->
initialize
(
YAML
::
Load
(
""
),
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to init file system"
;
return
-
1
;
}
}
return
0
;
}
//判断样本数据是否已就绪,就绪表明可以开始download
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
{
auto
done_file_path
=
_file_system
->
path_join
(
data_dir
,
_done_file_name
);
if
(
_file_system
->
exists
(
done_file_path
))
{
return
true
;
}
return
false
;
}
virtual
std
::
vector
<
std
::
string
>
data_file_list
(
const
std
::
string
&
data_dir
)
{
std
::
vector
<
std
::
string
>
data_files
;
for
(
auto
&
filepath
:
_file_system
->
list
(
data_dir
))
{
auto
filename
=
_file_system
->
path_split
(
filepath
).
second
;
if
(
filename
!=
_done_file_name
&&
string
::
begin_with
(
filename
,
_filename_prefix
))
{
data_files
.
push_back
(
std
::
move
(
filepath
));
}
}
return
data_files
;
}
//读取数据样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
framework
::
Channel
<
DataItem
>
data_channel
)
{
auto
file_list
=
data_file_list
(
data_dir
);
return
read_all
(
file_list
,
data_channel
);
}
virtual
int
read_all
(
const
std
::
vector
<
std
::
string
>&
file_list
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
{
auto
deleter
=
[](
framework
::
ChannelWriter
<
DataItem
>
*
writer
)
{
if
(
writer
)
{
writer
->
Flush
();
VLOG
(
3
)
<<
"writer auto flush"
;
}
delete
writer
;
};
std
::
unique_ptr
<
framework
::
ChannelWriter
<
DataItem
>
,
decltype
(
deleter
)
>
writer
(
new
framework
::
ChannelWriter
<
DataItem
>
(
data_channel
.
get
()),
deleter
);
DataItem
data_item
;
int
file_list_size
=
file_list
.
size
();
std
::
atomic
<
bool
>
is_failed
(
false
);
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
file_list_size
;
++
i
)
{
const
auto
&
filepath
=
file_list
[
i
];
if
(
!
is_failed
)
{
std
::
shared_ptr
<
FILE
>
fin
=
_file_system
->
open_read
(
filepath
,
_pipeline_cmd
);
if
(
fin
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to open file: "
<<
filepath
<<
", with cmd: "
<<
_pipeline_cmd
;
is_failed
=
true
;
continue
;
}
char
*
buffer
=
nullptr
;
size_t
buffer_size
=
0
;
ssize_t
line_len
=
0
;
while
((
line_len
=
getline
(
&
buffer
,
&
buffer_size
,
fin
.
get
()))
!=
-
1
)
{
if
(
line_len
>
0
&&
buffer
[
line_len
-
1
]
==
'\n'
)
{
buffer
[
--
line_len
]
=
'\0'
;
}
if
(
line_len
<=
0
)
{
continue
;
}
if
(
_parser
->
parse
(
buffer
,
line_len
,
data_item
)
==
0
)
{
(
*
writer
)
<<
std
::
move
(
data_item
);
}
}
if
(
buffer
!=
nullptr
)
{
free
(
buffer
);
buffer
=
nullptr
;
buffer_size
=
0
;
}
if
(
ferror
(
fin
.
get
())
!=
0
)
{
VLOG
(
2
)
<<
"fail to read file: "
<<
filepath
;
is_failed
=
true
;
continue
;
}
}
if
(
_file_system
->
err_no
()
!=
0
)
{
_file_system
->
reset_err_no
();
is_failed
=
true
;
continue
;
}
}
writer
->
Flush
();
if
(
!
(
*
writer
))
{
VLOG
(
2
)
<<
"fail when write to channel"
;
is_failed
=
true
;
}
data_channel
->
Close
();
return
is_failed
?
-
1
:
0
;
}
virtual
const
DataParser
*
get_parser
()
{
return
_parser
.
get
();
}
private:
std
::
string
_done_file_name
;
// without data_dir
std
::
string
_filename_prefix
;
std
::
shared_ptr
<
FileSystem
>
_file_system
;
};
REGISTER_CLASS
(
DataReader
,
LineDataReader
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
浏览文件 @
d7ee6ba1
...
...
@@ -48,9 +48,10 @@ public:
virtual
~
DataParser
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
virtual
int
parse
(
const
std
::
string
&
str
,
DataItem
&
data
)
const
{
return
parse
(
str
.
c_str
(),
str
.
size
(),
data
);
return
parse
(
str
.
c_str
(),
data
);
}
virtual
int
parse
(
const
char
*
str
,
size_t
len
,
DataItem
&
data
)
const
=
0
;
virtual
int
parse
(
const
char
*
str
,
DataItem
&
data
)
const
=
0
;
virtual
int
parse_to_sample
(
const
DataItem
&
data
,
SampleInstance
&
instance
)
const
=
0
;
};
REGISTER_REGISTERER
(
DataParser
);
...
...
@@ -59,29 +60,24 @@ class DataReader {
public:
DataReader
()
{}
virtual
~
DataReader
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
);
//判断样本数据是否已就绪,就绪表明可以开始download
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
=
0
;
//读取dir下文件列表
virtual
std
::
vector
<
std
::
string
>
data_file_list
(
const
std
::
string
&
data_dir
);
virtual
std
::
vector
<
std
::
string
>
data_file_list
(
const
std
::
string
&
data_dir
)
=
0
;
//读取目录下数据到样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
::
paddle
::
framework
::
Channel
<
DataItem
>
&
data_channel
)
=
0
;
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
=
0
;
//读取指定文件列表的数据到样本流中
virtual
int
read_all
(
const
std
::
vector
<
std
::
string
>&
data_list
,
::
paddle
::
framework
::
Channel
<
DataItem
>
&
data_channel
)
=
0
;
virtual
int
read_all
(
const
std
::
vector
<
std
::
string
>&
data_list
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
=
0
;
virtual
const
DataParser
*
get_parser
()
{
return
_parser
.
get
();
}
pr
ivate
:
pr
otected
:
std
::
shared_ptr
<
DataParser
>
_parser
;
//数据格式转换
std
::
string
_pipeline_cmd
;
//将文件流,重定向到pipeline_cmd,再读入
};
REGISTER_REGISTERER
(
DataReader
);
//TODO
//可读取HDFS/DISK上数据的Reader,数据按行分隔
//HDFS/DISK - FileLineReader
}
//namespace feed
}
//namespace custom_trainer
}
//namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
浏览文件 @
d7ee6ba1
...
...
@@ -6,15 +6,14 @@ namespace feed {
int
Dataset
::
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
if
(
!
config
[
"data_list"
]
)
{
VLOG
(
0
)
<<
"miss data_list config in dataset, please check"
;
if
(
config
[
"data_list"
].
Type
()
!=
YAML
::
NodeType
::
Map
)
{
VLOG
(
0
)
<<
"miss data_list config in dataset,
or type error
please check"
;
return
-
1
;
}
int
data_num
=
config
[
"data_list"
].
size
();
for
(
int
i
=
0
;
i
<
data_num
;
++
i
)
{
std
::
string
name
=
config
[
"data_list"
][
i
][
"name"
].
as
<
std
::
string
>
();
for
(
auto
&
data_config
:
config
[
"data_list"
])
{
std
::
string
name
=
data_config
.
first
.
as
<
std
::
string
>
();
auto
data_ptr
=
std
::
make_shared
<
DatasetContainer
>
();
if
(
data_ptr
->
initialize
(
config
[
"data_list"
][
i
]
,
context
)
!=
0
)
{
if
(
data_ptr
->
initialize
(
data_config
.
second
,
context
)
!=
0
)
{
VLOG
(
0
)
<<
"dataset initialize failed, name:"
<<
name
;
return
-
1
;
}
...
...
@@ -23,12 +22,27 @@ int Dataset::initialize(
return
0
;
}
inline
void
Dataset
::
pre_detect_data
(
uint64_t
epoch_id
)
{
for
(
auto
it
=
_data_containers
.
begin
();
it
!=
_data_containers
.
end
();
++
it
)
{
it
->
second
->
pre_detect_data
(
epoch_id
);
}
return
;
}
inline
void
Dataset
::
pre_detect_data
(
const
std
::
string
&
data_name
,
uint64_t
epoch_id
)
{
_data_containers
[
data_name
]
->
pre_detect_data
(
epoch_id
);
return
;
}
inline
DatasetStatus
Dataset
::
epoch_data_status
(
uint64_t
epoch_id
)
{
int
status
=
static_cast
<
int
>
(
DatasetStatus
::
Ready
);
for
(
auto
it
=
_data_containers
.
begin
();
it
!=
_data_containers
.
end
();
++
it
)
{
auto
d_status
=
static_cast
<
int
>
(
it
->
second
->
epoch_data_status
(
epoch_id
));
status
=
d_status
<
status
?
d_status
:
status
;
}
return
static_cast
<
DatasetStatus
>
(
status
);
}
inline
DatasetStatus
Dataset
::
epoch_data_status
(
const
std
::
string
&
data_name
,
uint64_t
epoch_id
)
{
return
_data_containers
[
data_name
]
->
epoch_data_status
(
epoch_id
);
...
...
paddle/fluid/train/custom_trainer/feed/dataset/dataset.h
浏览文件 @
d7ee6ba1
...
...
@@ -22,9 +22,11 @@ public:
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
);
//触发可预取的数据判断
virtual
void
pre_detect_data
(
uint64_t
epoch_id
);
virtual
void
pre_detect_data
(
const
std
::
string
&
data_name
,
uint64_t
epoch_id
);
//获取数据状态
virtual
DatasetStatus
epoch_data_status
(
uint64_t
epoch_id
);
virtual
DatasetStatus
epoch_data_status
(
const
std
::
string
&
data_name
,
uint64_t
epoch_id
);
//返回各DataContainer内的原始数据(maybe 压缩格式)
...
...
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
浏览文件 @
d7ee6ba1
...
...
@@ -23,6 +23,9 @@ int DatasetContainer::initialize(
//预取n轮样本数据
_prefetch_num
=
config
[
"prefetch_num"
].
as
<
int
>
();
_dataset_list
.
resize
(
_prefetch_num
);
for
(
int
i
=
0
;
i
<
_prefetch_num
;
++
i
)
{
_dataset_list
[
i
].
reset
(
new
DatasetInfo
);
}
_data_root_paths
=
paddle
::
string
::
split_string
(
config
[
"root_path"
].
as
<
std
::
string
>
(),
" "
);
...
...
@@ -48,21 +51,32 @@ void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
LOG
(
FATAL
)
<<
"timestamp:"
<<
timestamp
<<
" don't match interval:"
<<
epoch_accessor
->
epoch_time_interval
();
return
;
}
size_t
data_num
=
data_num_for_train
(
timestamp
,
epoch_accessor
->
epoch_time_interval
(),
_data_split_interval
);
uint64_t
data_timestamp
=
timestamp
%
_data_split_interval
==
0
?
timestamp
:
(
timestamp
/
_data_split_interval
+
1
)
*
_data_split_interval
;
std
::
vector
<
std
::
string
>
data_path_list
;
for
(
int
i
=
0
;
i
<
_data_root_paths
.
size
()
&&
status
==
0
;
++
i
)
{
for
(
int
j
=
0
;
j
<
data_num
&&
status
==
0
;
++
j
)
{
std
::
string
path_suffix
=
format_timestamp
(
data_timestamp
+
j
*
_data_split_interval
,
_data_path_formater
);
std
::
string
data_dir
=
_data_root_paths
[
i
]
+
"/"
+
path_suffix
;
status
=
read_data_list
(
data_dir
,
data_path_list
);
}
if
(
_downloader_thread
==
nullptr
)
{
_downloader_thread
.
reset
(
new
std
::
thread
([
this
,
timestamp
](){
async_download_data
(
timestamp
);
}));
}
if
(
status
==
0
)
{
auto
dataset_info
=
dataset
(
timestamp
);
dataset_info
->
timestamp
=
timestamp
;
dataset_info
->
file_path_list
=
std
::
move
(
data_path_list
);
dataset_info
->
status
=
DatasetStatus
::
Detected
;
for
(
int
detect_idx
=
0
;
detect_idx
<
_prefetch_num
;
++
detect_idx
)
{
if
(
DatasetStatus
::
Empty
!=
data_status
(
timestamp
))
{
continue
;
}
size_t
data_num
=
data_num_for_train
(
timestamp
,
epoch_accessor
->
epoch_time_interval
(),
_data_split_interval
);
uint64_t
data_timestamp
=
timestamp
%
_data_split_interval
==
0
?
timestamp
:
(
timestamp
/
_data_split_interval
+
1
)
*
_data_split_interval
;
std
::
vector
<
std
::
string
>
data_path_list
;
for
(
int
i
=
0
;
i
<
_data_root_paths
.
size
()
&&
status
==
0
;
++
i
)
{
for
(
int
j
=
0
;
j
<
data_num
&&
status
==
0
;
++
j
)
{
std
::
string
path_suffix
=
format_timestamp
(
data_timestamp
+
j
*
_data_split_interval
,
_data_path_formater
);
std
::
string
data_dir
=
_data_root_paths
[
i
]
+
"/"
+
path_suffix
;
status
=
read_data_list
(
data_dir
,
data_path_list
);
}
}
if
(
status
==
0
)
{
auto
dataset_info
=
dataset
(
timestamp
);
dataset_info
->
timestamp
=
timestamp
;
dataset_info
->
file_path_list
=
std
::
move
(
data_path_list
);
dataset_info
->
status
=
DatasetStatus
::
Detected
;
}
timestamp
+=
epoch_accessor
->
epoch_time_interval
();
}
return
;
}
...
...
@@ -134,7 +148,7 @@ void DatasetContainer::async_download_data(uint64_t start_timestamp) {
LOG
(
FATAL
)
<<
"timestamp:"
<<
start_timestamp
<<
" don't match interval:"
<<
epoch_accessor
->
epoch_time_interval
();
return
;
}
while
(
true
)
{
while
(
!
_stop_download
)
{
auto
dataset_info
=
dataset
(
start_timestamp
);
while
(
data_status
(
start_timestamp
)
!=
DatasetStatus
::
Detected
)
{
sleep
(
30
);
...
...
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
浏览文件 @
d7ee6ba1
...
...
@@ -30,6 +30,7 @@ enum class DatasetStatus {
Downloding
=
2
,
Ready
=
3
};
struct
DatasetInfo
{
uint64_t
timestamp
=
0
;
std
::
vector
<
std
::
string
>
file_path_list
;
...
...
@@ -40,10 +41,14 @@ struct DatasetInfo {
class
DatasetContainer
{
public:
DatasetContainer
()
{}
virtual
~
DatasetContainer
()
{}
virtual
~
DatasetContainer
()
{
if
(
_downloader_thread
!=
nullptr
)
{
_stop_download
=
true
;
_downloader_thread
->
join
();
}
}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
);
virtual
void
run
();
//触发可预取的数据判断
virtual
void
pre_detect_data
(
uint64_t
epoch_id
);
//获取数据状态
...
...
@@ -62,6 +67,7 @@ protected:
virtual
std
::
shared_ptr
<
DatasetInfo
>
dataset
(
uint64_t
timestamp
);
int
_prefetch_num
=
0
;
bool
_stop_download
=
false
;
int
_data_split_interval
=
60
;
//样本切分周期(秒)
YAML
::
Node
_dataset_config
;
std
::
string
_data_path_formater
;
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
浏览文件 @
d7ee6ba1
...
...
@@ -43,86 +43,84 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load(
}
struct
SimpleExecutor
::
Context
{
Context
(
const
::
paddle
::
platform
::
Place
&
place
)
:
place
(
place
),
executor
(
place
)
{
}
const
::
paddle
::
platform
::
Place
&
place
;
::
paddle
::
framework
::
Executor
executor
;
::
std
::
unique_ptr
<::
paddle
::
framework
::
ProgramDesc
>
main_program
;
::
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepare_context
;
details
::
TensorArrayBatchCleaner
tensor_array_batch_cleaner
;
};
SimpleExecutor
::
SimpleExecutor
()
{
}
SimpleExecutor
::~
SimpleExecutor
()
{
}
int
SimpleExecutor
::
initialize
(
YAML
::
Node
exe_config
,
class
SimpleExecutor
:
public
Executor
{
public:
SimpleExecutor
()
{};
virtual
~
SimpleExecutor
()
{};
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
paddle
::
framework
::
InitDevices
(
false
);
if
(
exe_config
[
"num_threads"
])
{
paddle
::
platform
::
SetNumThreads
(
exe_config
[
"num_threads"
].
as
<
int
>
());
}
else
{
paddle
::
platform
::
SetNumThreads
(
1
);
}
if
(
!
exe_config
[
"startup_program"
]
||
!
exe_config
[
"main_program"
])
{
VLOG
(
2
)
<<
"fail to load config"
;
return
-
1
;
}
paddle
::
framework
::
InitDevices
(
false
);
if
(
exe_config
[
"num_threads"
])
{
paddle
::
platform
::
SetNumThreads
(
exe_config
[
"num_threads"
].
as
<
int
>
());
}
else
{
paddle
::
platform
::
SetNumThreads
(
1
);
}
try
{
_context
.
reset
(
new
SimpleExecutor
::
Context
(
context_ptr
->
cpu_place
));
auto
startup_program
=
Load
(
&
_context
->
executor
,
exe_config
[
"startup_program"
].
as
<
std
::
string
>
());
if
(
startup_program
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to load startup_program: "
<<
exe_config
[
"startup_program"
].
as
<
std
::
string
>
();
if
(
!
exe_config
[
"startup_program"
]
||
!
exe_config
[
"main_program"
])
{
VLOG
(
2
)
<<
"fail to load config"
;
return
-
1
;
}
_context
->
executor
.
Run
(
*
startup_program
,
this
->
scope
(),
0
,
false
,
true
);
_context
->
main_program
=
Load
(
&
_context
->
executor
,
exe_config
[
"main_program"
].
as
<
std
::
string
>
());
if
(
_context
->
main_program
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to load main_program: "
<<
exe_config
[
"main_program"
].
as
<
std
::
string
>
();
try
{
_context
.
reset
(
new
SimpleExecutor
::
Context
(
context_ptr
->
cpu_place
));
auto
startup_program
=
Load
(
&
_context
->
executor
,
exe_config
[
"startup_program"
].
as
<
std
::
string
>
());
if
(
startup_program
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to load startup_program: "
<<
exe_config
[
"startup_program"
].
as
<
std
::
string
>
();
return
-
1
;
}
_context
->
executor
.
Run
(
*
startup_program
,
this
->
scope
(),
0
,
false
,
true
);
_context
->
main_program
=
Load
(
&
_context
->
executor
,
exe_config
[
"main_program"
].
as
<
std
::
string
>
());
if
(
_context
->
main_program
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to load main_program: "
<<
exe_config
[
"main_program"
].
as
<
std
::
string
>
();
return
-
1
;
}
_context
->
prepare_context
=
_context
->
executor
.
Prepare
(
*
_context
->
main_program
,
0
);
_context
->
executor
.
CreateVariables
(
*
_context
->
main_program
,
this
->
scope
(),
0
);
}
catch
(
::
paddle
::
platform
::
EnforceNotMet
&
err
)
{
VLOG
(
2
)
<<
err
.
what
();
_context
.
reset
(
nullptr
);
return
-
1
;
}
_context
->
prepare_context
=
_context
->
executor
.
Prepare
(
*
_context
->
main_program
,
0
);
_context
->
executor
.
CreateVariables
(
*
_context
->
main_program
,
this
->
scope
(),
0
);
}
catch
(
::
paddle
::
platform
::
EnforceNotMet
&
err
)
{
VLOG
(
2
)
<<
err
.
what
();
_context
.
reset
(
nullptr
);
return
-
1
;
}
return
0
;
}
int
SimpleExecutor
::
run
()
{
if
(
_context
==
nullptr
)
{
VLOG
(
2
)
<<
"need initialize before run"
;
return
-
1
;
return
0
;
}
try
{
_context
->
executor
.
RunPreparedContext
(
_context
->
prepare_context
.
get
(),
this
->
scope
(),
false
,
/* don't create local scope each time*/
false
/* don't create variable each time */
);
// For some other vector like containers not cleaned after each batch.
_context
->
tensor_array_batch_cleaner
.
CollectNoTensorVars
(
this
->
scope
());
_context
->
tensor_array_batch_cleaner
.
ResetNoTensorVars
();
}
catch
(
::
paddle
::
platform
::
EnforceNotMet
&
err
)
{
VLOG
(
2
)
<<
err
.
what
();
return
-
1
;
virtual
int
run
()
{
if
(
_context
==
nullptr
)
{
VLOG
(
2
)
<<
"need initialize before run"
;
return
-
1
;
}
try
{
_context
->
executor
.
RunPreparedContext
(
_context
->
prepare_context
.
get
(),
this
->
scope
(),
false
,
/* don't create local scope each time*/
false
/* don't create variable each time */
);
// For some other vector like containers not cleaned after each batch.
_context
->
tensor_array_batch_cleaner
.
CollectNoTensorVars
(
this
->
scope
());
_context
->
tensor_array_batch_cleaner
.
ResetNoTensorVars
();
}
catch
(
::
paddle
::
platform
::
EnforceNotMet
&
err
)
{
VLOG
(
2
)
<<
err
.
what
();
return
-
1
;
}
return
0
;
}
return
0
;
}
protected:
struct
Context
{
Context
(
const
::
paddle
::
platform
::
Place
&
place
)
:
place
(
place
),
executor
(
place
)
{
}
const
::
paddle
::
platform
::
Place
&
place
;
::
paddle
::
framework
::
Executor
executor
;
::
std
::
unique_ptr
<::
paddle
::
framework
::
ProgramDesc
>
main_program
;
::
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepare_context
;
details
::
TensorArrayBatchCleaner
tensor_array_batch_cleaner
;
};
std
::
unique_ptr
<
Context
>
_context
;
};
REGISTER_CLASS
(
Executor
,
SimpleExecutor
);
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.h
浏览文件 @
d7ee6ba1
...
...
@@ -42,18 +42,6 @@ protected:
};
REGISTER_REGISTERER
(
Executor
);
class
SimpleExecutor
:
public
Executor
{
public:
SimpleExecutor
();
virtual
~
SimpleExecutor
();
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
int
run
();
protected:
struct
Context
;
std
::
unique_ptr
<
Context
>
_context
;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc
0 → 100644
浏览文件 @
d7ee6ba1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// A dispatching FileSystem: routes each path to a concrete FileSystem
// implementation chosen by the path's scheme prefix (e.g. "hdfs:"),
// falling back to a LocalFileSystem registered under "default".
class AutoFileSystem : public FileSystem {
public:
    // Build the prefix -> FileSystem table from config["file_systems"]
    // (a YAML map of prefix -> {class: ..., ...}) and ensure a "default"
    // local file system always exists.
    // Returns 0 on success, -1 on any creation/initialization failure.
    int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
        _file_system.clear();
        if (config && config["file_systems"] &&
            config["file_systems"].Type() == YAML::NodeType::Map) {
            for (auto& prefix_fs : config["file_systems"]) {
                std::unique_ptr<FileSystem> fs(
                    CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>("")));
                if (fs == nullptr) {
                    VLOG(2) << "fail to create class: "
                            << prefix_fs.second["class"].as<std::string>("");
                    return -1;
                }
                if (fs->initialize(prefix_fs.second, context) != 0) {
                    VLOG(2) << "fail to initialize class: "
                            << prefix_fs.second["class"].as<std::string>("");
                    // Fix: the original returned 0 here, reporting success even
                    // though a configured file system failed to initialize.
                    return -1;
                }
                _file_system.emplace(prefix_fs.first.as<std::string>(""), std::move(fs));
            }
        }
        // Guarantee a usable fallback so get_file_system() never returns null.
        if (_file_system.find("default") == _file_system.end()) {
            std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
            if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) {
                return -1;
            }
            _file_system.emplace("default", std::move(fs));
        }
        return 0;
    }

    // All operations below delegate to the file system selected by the path.

    std::shared_ptr<FILE> open_read(const std::string& path,
                                    const std::string& converter) override {
        return get_file_system(path)->open_read(path, converter);
    }

    std::shared_ptr<FILE> open_write(const std::string& path,
                                     const std::string& converter) override {
        return get_file_system(path)->open_write(path, converter);
    }

    int64_t file_size(const std::string& path) override {
        return get_file_system(path)->file_size(path);
    }

    void remove(const std::string& path) override {
        get_file_system(path)->remove(path);
    }

    std::vector<std::string> list(const std::string& path) override {
        return get_file_system(path)->list(path);
    }

    std::string tail(const std::string& path) override {
        return get_file_system(path)->tail(path);
    }

    bool exists(const std::string& path) override {
        return get_file_system(path)->exists(path);
    }

    void mkdir(const std::string& path) override {
        get_file_system(path)->mkdir(path);
    }

    // Select the file system whose registered prefix (scheme up to and
    // including the ':') matches the path; otherwise the "default" one.
    FileSystem* get_file_system(const std::string& path) {
        auto pos = path.find_first_of(":");
        if (pos != std::string::npos) {
            auto substr = path.substr(0, pos + 1);  // prefix includes the ':'
            auto fs_it = _file_system.find(substr);
            if (fs_it != _file_system.end()) {
                return fs_it->second.get();
            }
        }
        VLOG(5) << "path: " << path << ", select default file system";
        return _file_system["default"].get();
    }

    // Report an error if this wrapper or ANY wrapped file system has one.
    // The lazy fold into _err_no mutates state from a const method, hence
    // the const_cast (kept: _err_no is declared in the non-modifiable base).
    int err_no() const override {
        if (_err_no == 0) {
            for (const auto& file_system : _file_system) {
                if (file_system.second->err_no() != 0) {
                    const_cast<int&>(_err_no) = -1;
                    break;
                }
            }
        }
        return FileSystem::err_no();
    }

    // Clear the aggregated flag and every wrapped file system's flag.
    void reset_err_no() override {
        _err_no = 0;
        for (auto& file_system : _file_system) {
            file_system.second->reset_err_no();
        }
    }

private:
    // Maps path prefix ("scheme:") -> owned concrete file system.
    std::unordered_map<std::string, std::unique_ptr<FileSystem>> _file_system;
};
REGISTER_CLASS
(
FileSystem
,
AutoFileSystem
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/file_system.cc
0 → 100644
浏览文件 @
d7ee6ba1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include <string>
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Join a directory and a path component, inserting exactly one '/'
// separator; an empty dir yields the path unchanged.
std::string FileSystem::path_join(const std::string& dir, const std::string& path) {
    if (dir.empty()) {
        return path;
    }
    const bool ends_with_separator = (dir.back() == '/');
    return ends_with_separator ? (dir + path) : (dir + '/' + path);
}
// Split a path into {directory, file name} at the last '/'.
// A path with no separator is treated as a file in ".".
std::pair<std::string, std::string> FileSystem::path_split(const std::string& path) {
    const size_t last_sep = path.find_last_of('/');
    if (last_sep == std::string::npos) {
        return std::make_pair(std::string("."), path);
    }
    return std::make_pair(path.substr(0, last_sep), path.substr(last_sep + 1));
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/file_system.h
0 → 100644
浏览文件 @
d7ee6ba1
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <cstdio>
#include <vector>
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include <yaml-cpp/yaml.h>
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Abstract interface over a file store (local disk, HDFS, ...).
// Concrete implementations are registered via REGISTER_CLASS and are
// typically selected by path prefix (see AutoFileSystem).
class FileSystem {
public:
    FileSystem() {}
    virtual ~FileSystem() {}

    // Configure the file system from its YAML node. Returns 0 on success.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;

    // Open `path` for reading, optionally piping it through the shell
    // command `converter`. The returned FILE is closed by its deleter.
    virtual std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter) = 0;

    // Open `path` for writing, optionally piping through `converter`.
    virtual std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) = 0;

    // Size of the file in bytes (implementations may not support this).
    virtual int64_t file_size(const std::string& path) = 0;

    // Delete a file or directory tree.
    virtual void remove(const std::string& path) = 0;

    // List the files directly under `path`.
    virtual std::vector<std::string> list(const std::string& path) = 0;

    // Last line of the file at `path`.
    virtual std::string tail(const std::string& path) = 0;

    // Whether a file or directory exists at `path`.
    virtual bool exists(const std::string& path) = 0;

    // Create the directory (and any missing parents).
    virtual void mkdir(const std::string& path) = 0;

    // Join dir and path with a single '/' (shared default implementation).
    virtual std::string path_join(const std::string& dir, const std::string& path);

    // Split into {directory, file name} at the last '/'.
    virtual std::pair<std::string, std::string> path_split(const std::string& path);

    // Sticky error flag: 0 means no error since the last reset.
    virtual int err_no() const {
        return _err_no;
    }

    // True when no error has been recorded.
    inline operator bool() {
        return err_no() == 0;
    }

    // Clear the error flag.
    virtual void reset_err_no() {
        _err_no = 0;
    }

protected:
    int _err_no = 0;  // last-error indicator, written by implementations
};
REGISTER_REGISTERER
(
FileSystem
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
0 → 100644
浏览文件 @
d7ee6ba1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include <string>
#include <unordered_map>
#include <tuple>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/piece.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// FileSystem backed by the `hadoop fs` command-line client. Every
// operation shells out; per-cluster credentials (UGI) are selected by
// the host part of the path.
class HadoopFileSystem : public FileSystem {
public:
    // Reads:
    //   buffer_size  - stdio buffer for opened pipes (0 = stdio default)
    //   hdfs_command - client binary/prefix (default "hadoop fs")
    //   ugis         - map of cluster-host -> ugi string; a "default"
    //                  entry is mandatory.
    int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
        _buffer_size = config["buffer_size"].as<size_t>(0);
        _hdfs_command = config["hdfs_command"].as<std::string>("hadoop fs");
        _ugi.clear();
        if (config["ugis"] && config["ugis"].Type() == YAML::NodeType::Map) {
            for (const auto& prefix_ugi : config["ugis"]) {
                _ugi.emplace(prefix_ugi.first.as<std::string>(),
                             prefix_ugi.second.as<std::string>());
            }
        }
        if (_ugi.find("default") == _ugi.end()) {
            VLOG(2) << "fail to load default ugi";
            return -1;
        }
        return 0;
    }

    // Stream a remote file through `hadoop fs -text` (decompresses .gz)
    // or `-cat`, then through the optional converter pipeline.
    std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter) override {
        std::string cmd;
        if (string::end_with(path, ".gz")) {
            cmd = string::format_string("%s -text \"%s\"", hdfs_command(path).c_str(), path.c_str());
        } else {
            cmd = string::format_string("%s -cat \"%s\"", hdfs_command(path).c_str(), path.c_str());
        }
        bool is_pipe = true;
        shell_add_read_converter(cmd, is_pipe, converter);
        return shell_open(cmd, is_pipe, "r", _buffer_size, &_err_no);
    }

    // Write through `hadoop fs -put -` (stdin upload), with optional
    // gzip and converter stages prepended.
    std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) override {
        std::string cmd = string::format_string("%s -put - \"%s\"", hdfs_command(path).c_str(), path.c_str());
        bool is_pipe = true;
        // NOTE(review): the suffix tested here is `.gz\"` (with a trailing
        // escaped quote), which no plain path ends with — so the gzip stage
        // looks unreachable. Presumably `.gz` was intended; confirm upstream
        // before changing, since the scrape makes the literal ambiguous.
        if (string::end_with(path, ".gz\"")) {
            shell_add_write_converter(cmd, is_pipe, "gzip");
        }
        shell_add_write_converter(cmd, is_pipe, converter);
        return shell_open(cmd, is_pipe, "w", _buffer_size, &_err_no);
    }

    // Not supported over the shell client: flags an error and returns 0.
    int64_t file_size(const std::string& path) override {
        _err_no = -1;
        VLOG(2) << "not support";
        return 0;
    }

    // Recursive remove; errors are deliberately swallowed ("; true").
    // NOTE(review): uses _hdfs_command directly, without the per-path UGI
    // used by every other operation — confirm this is intentional.
    void remove(const std::string& path) override {
        if (path == "") {
            return;
        }
        shell_execute(string::format_string("%s -rmr %s &>/dev/null; true",
                                            _hdfs_command.c_str(), path.c_str()));
    }

    // List regular files under `path`, retrying while the shell pipeline
    // reports a transient failure (err_no == -1). Each `-ls` output line
    // with 8 fields contributes its last field (the file path), re-prefixed
    // with the scheme+cluster extracted from the query path.
    std::vector<std::string> list(const std::string& path) override {
        if (path == "") {
            return {};
        }
        auto paths = split_path(path);
        int err_no = 0;
        std::vector<std::string> list;
        do {
            err_no = 0;
            std::shared_ptr<FILE> pipe;
            // `grep ^-` keeps regular files; the `[ $? != 2 ]` guard keeps the
            // pipeline's exit status sane when grep matches nothing.
            pipe = shell_popen(
                string::format_string("%s -ls %s | ( grep ^- ; [ $? != 2 ] )",
                                      hdfs_command(path).c_str(), path.c_str()),
                "r", &err_no);
            string::LineFileReader reader;
            list.clear();
            while (reader.getline(&*pipe)) {
                std::vector<std::string> line = string::split_string(reader.get());
                if (line.size() != 8) {
                    continue;
                }
                list.push_back(get_prefix(paths) + line[7]);
            }
        } while (err_no == -1);
        return list;
    }

    // Last line of the (possibly compressed) file, via `-text | tail -1`.
    std::string tail(const std::string& path) override {
        if (path == "") {
            return "";
        }
        return shell_get_command_output(string::format_string(
            "%s -text %s | tail -1 ", hdfs_command(path).c_str(), path.c_str()));
    }

    // Existence probe via `-test -e`; "0" on stdout means it exists.
    bool exists(const std::string& path) override {
        std::string test = shell_get_command_output(string::format_string(
            "%s -test -e %s ; echo $?", hdfs_command(path).c_str(), path.c_str()));
        if (string::trim_spaces(test) == "0") {
            return true;
        }
        return false;
    }

    // Create a directory; failure is swallowed ("; true").
    void mkdir(const std::string& path) override {
        if (path == "") {
            return;
        }
        shell_execute(string::format_string("%s -mkdir %s; true",
                                            hdfs_command(path).c_str(), path.c_str()));
    }

    // Pick the client command with the UGI registered for this path's
    // cluster host, falling back to the "default" UGI.
    std::string hdfs_command(const std::string& path) {
        auto paths = split_path(path);
        auto it = _ugi.find(std::get<1>(paths).ToString());
        if (it != _ugi.end()) {
            return hdfs_command_with_ugi(it->second);
        }
        VLOG(5) << "path: " << path << ", select default ugi";
        return hdfs_command_with_ugi(_ugi["default"]);
    }

    // Attach credentials: `<hdfs_command> -Dhadoop.job.ugi="<ugi>"`.
    std::string hdfs_command_with_ugi(std::string ugi) {
        return string::format_string("%s -Dhadoop.job.ugi=\"%s\"",
                                     _hdfs_command.c_str(), ugi.c_str());
    }

private:
    // Rebuild "scheme://cluster" (or just "scheme:" when there is no
    // cluster part) from a split path, for re-prefixing listing results.
    std::string get_prefix(const std::tuple<string::Piece, string::Piece, string::Piece>& paths) {
        if (std::get<1>(paths).len() == 0) {
            return std::get<0>(paths).ToString();
        }
        return std::get<0>(paths).ToString() + "//" + std::get<1>(paths).ToString();
    }

    // parse "xxx://abc.def:8756/user" as "xxx:", "abc.def:8756", "/user"
    // parse "xxx:/user" as "xxx:", "", "/user"
    // parse "xxx://abc.def:8756" as "xxx:", "abc.def:8756", ""
    // parse "other" as "", "", "other"
    std::tuple<string::Piece, string::Piece, string::Piece> split_path(string::Piece path) {
        // Default result: empty scheme, empty cluster, whole path as tail.
        std::tuple<string::Piece, string::Piece, string::Piece> result{
            string::SubStr(path, 0, 0), string::SubStr(path, 0, 0), path};
        // If ':' is absent, Find returns npos and npos + 1 wraps to 0,
        // leaving fs_pos == 0 and the scheme empty — the "other" case.
        auto fs_pos = string::Find(path, ':', 0) + 1;
        if (path.len() > fs_pos) {
            std::get<0>(result) = string::SubStr(path, 0, fs_pos);  // "xxx:"
            path = string::SkipPrefix(path, fs_pos);
            if (string::HasPrefix(path, "//")) {
                // "//host[/rest]": the cluster runs to the next '/'.
                path = string::SkipPrefix(path, 2);
                auto end_pos = string::Find(path, '/', 0);
                if (end_pos != string::Piece::npos) {
                    std::get<1>(result) = string::SubStr(path, 0, end_pos);
                    std::get<2>(result) = string::SkipPrefix(path, end_pos);
                } else {
                    std::get<1>(result) = path;  // host only, empty tail
                }
            } else {
                std::get<2>(result) = path;  // "xxx:/user": no cluster part
            }
        }
        return result;
    }

    size_t _buffer_size = 0;        // stdio buffer for opened pipes
    std::string _hdfs_command;      // e.g. "hadoop fs"
    std::unordered_map<std::string, std::string> _ugi;  // cluster -> ugi
};
REGISTER_CLASS
(
FileSystem
,
HadoopFileSystem
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc
0 → 100644
浏览文件 @
d7ee6ba1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include <string>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// FileSystem over the local disk. Plain reads/writes use stdio directly;
// .gz paths and explicit converters turn the open into a shell pipeline.
class LocalFileSystem : public FileSystem {
public:
    // Only configuration: optional stdio buffer size (0 = stdio default).
    int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
        _buffer_size = config["buffer_size"].as<size_t>(0);
        return 0;
    }

    // Open for reading; `.gz` files are piped through zcat, then through
    // the optional converter command.
    std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter) override {
        std::string cmd = path;
        bool is_pipe = false;
        if (string::end_with(path, ".gz")) {
            shell_add_read_converter(cmd, is_pipe, "zcat");
        }
        shell_add_read_converter(cmd, is_pipe, converter);
        return shell_open(cmd, is_pipe, "r", _buffer_size);
    }

    // Open for writing, creating parent directories first; `.gz` output is
    // piped through gzip, then the optional converter.
    std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) override {
        std::string cmd = path;
        shell_execute(string::format_string("mkdir -p $(dirname \"%s\")", path.c_str()));
        bool is_pipe = false;
        if (string::end_with(path, ".gz")) {
            shell_add_write_converter(cmd, is_pipe, "gzip");
        }
        shell_add_write_converter(cmd, is_pipe, converter);
        return shell_open(cmd, is_pipe, "w", _buffer_size);
    }

    // File size via stat(2); a failed stat aborts the process (LOG(FATAL)).
    int64_t file_size(const std::string& path) override {
        struct stat buf;
        if (0 != stat(path.c_str(), &buf)) {
            LOG(FATAL) << "file stat not zero";
            return -1;  // unreachable after LOG(FATAL), kept for form
        }
        return (int64_t)buf.st_size;
    }

    // Recursive delete; an empty path is a no-op.
    void remove(const std::string& path) override {
        if (path == "") {
            return;
        }
        shell_execute(string::format_string("rm -rf %s", path.c_str()));
    }

    // Regular files directly under `path` (non-recursive), via find(1).
    std::vector<std::string> list(const std::string& path) override {
        if (path == "") {
            return {};
        }
        std::shared_ptr<FILE> pipe;
        pipe = shell_popen(
            string::format_string("find %s -maxdepth 1 -type f", path.c_str()),
            "r", &_err_no);
        string::LineFileReader reader;
        std::vector<std::string> list;
        while (reader.getline(&*pipe)) {
            list.push_back(reader.get());
        }
        return list;
    }

    // Last line of the file, via tail(1).
    std::string tail(const std::string& path) override {
        if (path == "") {
            return "";
        }
        return shell_get_command_output(string::format_string("tail -1 %s ", path.c_str()));
    }

    // True if `path` is an existing file or directory (two shell probes).
    bool exists(const std::string& path) override {
        std::string test_f = shell_get_command_output(
            string::format_string("[ -f %s ] ; echo $?", path.c_str()));
        if (string::trim_spaces(test_f) == "0") {
            return true;
        }
        std::string test_d = shell_get_command_output(
            string::format_string("[ -d %s ] ; echo $?", path.c_str()));
        if (string::trim_spaces(test_d) == "0") {
            return true;
        }
        return false;
    }

    // mkdir -p; an empty path is a no-op.
    void mkdir(const std::string& path) override {
        if (path == "") {
            return;
        }
        shell_execute(string::format_string("mkdir -p %s", path.c_str()));
    }

private:
    size_t _buffer_size = 0;  // stdio buffer for shell_open
};
REGISTER_CLASS
(
FileSystem
,
LocalFileSystem
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/shell.cc
0 → 100644
浏览文件 @
d7ee6ba1
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Append a converter stage to a write target. On first use the plain file
// path is rewritten into a shell pipeline that redirects into the file and
// `is_pipe` flips to true; subsequent converters are prepended as `cmd | ...`.
// An empty converter leaves everything untouched.
void shell_add_write_converter(std::string& path, bool& is_pipe,  // NOLINT
                               const std::string& converter) {
    if (converter.empty()) {
        return;
    }
    if (is_pipe) {
        // Already a pipeline: add the converter as an upstream stage.
        path = string::format_string("%s | %s", converter.c_str(), path.c_str());
    } else {
        // Plain file: run the converter and redirect its stdout into the file.
        path = string::format_string("( %s ) > \"%s\"", converter.c_str(), path.c_str());
        is_pipe = true;
    }
}
// Append a converter stage to a read source. On first use the plain file
// path becomes a pipeline fed from the file and `is_pipe` flips to true;
// subsequent converters are appended as `... | converter`.
// An empty converter leaves everything untouched.
void shell_add_read_converter(std::string& path, bool& is_pipe,  // NOLINT
                              const std::string& converter) {
    if (converter.empty()) {
        return;
    }
    if (is_pipe) {
        // Already a pipeline: add the converter as a downstream stage.
        path = string::format_string("%s | %s", path.c_str(), converter.c_str());
    } else {
        // Plain file: feed it into the converter's stdin.
        path = string::format_string("( %s ) < \"%s\"", converter.c_str(), path.c_str());
        is_pipe = true;
    }
}
// Open `path` either as a plain file (shell_fopen) or as a shell pipeline
// (shell_popen), optionally installing a full stdio buffer of
// `buffer_size` bytes. Pipe errors are reported through *err_no.
std::shared_ptr<FILE> shell_open(const std::string& path, bool is_pipe,
                                 const std::string& mode, size_t buffer_size,
                                 int* err_no) {
    std::shared_ptr<FILE> fp = nullptr;
    if (!is_pipe) {
        fp = shell_fopen(path, mode);
    } else {
        fp = shell_popen(path, mode, err_no);
    }
    if (buffer_size > 0) {
        char* buffer = new char[buffer_size];
        CHECK_EQ(0, setvbuf(&*fp, buffer, _IOFBF, buffer_size));
        // Rewrap with an aliasing shared_ptr whose deleter (a) holds the
        // original fp so the file is closed only when this wrapper dies,
        // and (b) frees `buffer` strictly after that close. The unique()
        // check asserts no other owner is keeping the FILE alive while its
        // buffer is being deleted.
        fp = {&*fp, [fp, buffer](FILE*) mutable {  // NOLINT
                  CHECK(fp.unique());              // NOLINT
                  fp = nullptr;                    // closes the FILE first
                  delete[] buffer;                 // then the buffer it used
              }};
    }
    return fp;
}
// fopen wrapper returning a shared_ptr whose deleter fclose()s the file.
// Open or close failure aborts the process (LOG(FATAL)).
// On Windows/macOS this is compiled out and returns nullptr.
std::shared_ptr<FILE> shell_fopen(const std::string& path, const std::string& mode) {
#if defined _WIN32 || defined __APPLE__
    return nullptr;
#else
    if (shell_verbose()) {
        LOG(INFO) << "Opening file[" << path << "] with mode[" << mode << "]";
    }
    FILE* fp;
    if (!(fp = fopen(path.c_str(), mode.c_str()))) {
        LOG(FATAL) << "fopen fail, path[" << path << "], mode[" << mode << "]";
    }
    // The deleter captures `path` by value for the log/error messages.
    return {fp, [path](FILE* fp) {
        if (shell_verbose()) {
            LOG(INFO) << "Closing file[" << path << "]";
        }
        if (0 != fclose(fp)) {
            LOG(FATAL) << "fclose fail, path[" << path << "]";
        }
    }};
#endif
}
// Close all open file descriptors
// The implementation is async signal safe
// Mostly copy from CPython code
// Close all open file descriptors
// The implementation is async signal safe
// Mostly copy from CPython code
//
// Enumerates /proc/self/fd with the raw getdents syscall (no opendir/
// malloc, so it stays safe between vfork and exec) and closes every fd
// except stdin/stdout/stderr (fd < 3) and the enumeration fd itself.
// Returns 0 on success, -1 on failure (after LOG(FATAL)).
// Compiled out (returns 0) on Windows/macOS.
static int close_open_fds_internal() {
#if defined _WIN32 || defined __APPLE__
    return 0;
#else
    // Matches the kernel's getdents record layout for this directory.
    struct linux_dirent {
        long d_ino = 0;  // NOLINT
        off_t d_off;
        unsigned short d_reclen = 0;  // NOLINT
        char d_name[256];
    };

    int dir_fd = -1;
    if ((dir_fd = open("/proc/self/fd", O_RDONLY)) < 0) {
        LOG(FATAL) << "proc/self/fd open fail";
        return -1;
    }
    char buffer[sizeof(linux_dirent)];

    for (;;) {
        int bytes = 0;
        if ((bytes = syscall(SYS_getdents, dir_fd,
                             reinterpret_cast<linux_dirent*>(buffer),
                             sizeof(buffer))) < 0) {
            LOG(FATAL) << "syscall fail";
            return -1;
        }
        if (bytes == 0) {
            break;  // directory exhausted
        }

        linux_dirent* entry = NULL;
        // Records are variable-length: advance by each entry's d_reclen.
        for (int offset = 0; offset < bytes; offset += entry->d_reclen) {
            entry = reinterpret_cast<linux_dirent*>(buffer + offset);
            // Parse the fd number by hand (no atoi — keep it signal safe).
            int fd = 0;
            const char* s = entry->d_name;
            while (*s >= '0' && *s <= '9') {
                fd = fd * 10 + (*s - '0');
                s++;
            }
            // Skip non-numeric names ("." / ".."), our own dir_fd, and
            // the standard streams.
            if (s != entry->d_name && fd != dir_fd && fd >= 3) {
                close(fd);
            }
        }
    }

    close(dir_fd);
    return 0;
#endif
}
// Forks (via vfork) and execs `real_cmd` under /bin/bash -c, with the child's
// stdout (do_read) or stdin (!do_read) wired to `child_end`.
// Returns the child pid in the parent, -1 on failure; never returns in the
// child (exec or exit(127)).
static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
                                     int parent_end, int child_end) {
#if defined _WIN32 || defined __APPLE__
    return 0;
#else
    int child_pid = -1;
    // Too frequent calls to fork() makes openmpi very slow. Use vfork() instead.
    // But vfork() is very dangerous. Be careful.
    if ((child_pid = vfork()) < 0) {
        return -1;
    }
    // The following code is async signal safe (No memory allocation, no access to
    // global data, etc.)
    if (child_pid != 0) {
        // Parent: hand the pid back to the caller.
        return child_pid;
    }
    // Child: pick the standard fd to replace (1 = stdout for reading pipes,
    // 0 = stdin for writing pipes).
    int child_std_end = do_read ? 1 : 0;
    close(parent_end);
    if (child_end != child_std_end) {
        if (dup2(child_end, child_std_end) != child_std_end) {
            return -1;
        }
        close(child_end);
    }
    // Drop every inherited descriptor except std streams before exec'ing.
    close_open_fds_internal();
    if (execl("/bin/bash", "bash", "-c", real_cmd, NULL) < 0) {
        return -1;
    }
    // Unreachable unless execl failed without returning <0; mirror shell
    // convention for "command not found".
    exit(127);
#endif
}
// popen()-like helper: runs `cmd` under bash with pipefail and returns a
// FILE* reading the command's stdout (mode "r") or writing its stdin
// (mode "w").  The returned shared_ptr's deleter closes the stream, waits for
// the child, and reports failures through `err_no` (*err_no = -1).
// Returns NULL (with *err_no = -1) on setup failure.
std::shared_ptr<FILE> shell_popen(const std::string& cmd,
                                  const std::string& mode, int* err_no) {
#if defined _WIN32 || defined __APPLE__
    return nullptr;
#else
    bool do_read = mode == "r";
    bool do_write = mode == "w";
    if (!(do_read || do_write)) {
        *err_no = -1;
        return NULL;
    }
    if (shell_verbose()) {
        LOG(INFO) << "Opening pipe[" << cmd << "] with mode[" << mode << "]";
    }
    // pipefail so a failure anywhere inside a multi-stage pipeline surfaces
    // in the child's exit status.
    std::string real_cmd = "set -o pipefail; " + cmd;
    int pipe_fds[2];
    if (pipe(pipe_fds) != 0) {
        *err_no = -1;
        return NULL;
    }
    int parent_end = 0;
    int child_end = 0;
    if (do_read) {
        parent_end = pipe_fds[0];
        child_end = pipe_fds[1];
    } else if (do_write) {
        parent_end = pipe_fds[1];
        child_end = pipe_fds[0];
    }
    int child_pid = shell_popen_fork_internal(real_cmd.c_str(), do_read,
                                              parent_end, child_end);
    close(child_end);
    // Fix: the original ignored fork failure.  With child_pid == -1 the
    // deleter's waitpid(-1, ...) would block on / reap an arbitrary child
    // process.  Fail fast instead and release the parent end of the pipe.
    if (child_pid < 0) {
        *err_no = -1;
        close(parent_end);
        return NULL;
    }
    fcntl(parent_end, F_SETFD, FD_CLOEXEC);
    FILE* fp;
    if ((fp = fdopen(parent_end, mode.c_str())) == NULL) {
        *err_no = -1;
        return NULL;
    }
    // Deleter: flush/close the stream, reap the child, and classify the exit
    // status.  Exit 0, death-by-SIGPIPE, and "already reaped" (ECHILD) are
    // treated as success; anything else flags *err_no.
    return {fp, [child_pid, cmd, err_no](FILE* fp) {
        if (shell_verbose()) {
            LOG(INFO) << "Closing pipe[" << cmd << "]";
        }
        if (fclose(fp) != 0) {
            *err_no = -1;
        }
        int wstatus = -1;
        waitpid(child_pid, &wstatus, 0);
        if (wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
            (wstatus == -1 && errno == ECHILD)) {
            // Normal termination paths; nothing to report.
        } else {
            *err_no = -1;
            LOG(WARNING) << "status[" << wstatus << "], cmd[" << cmd << "]"
                         << ", err_no[" << *err_no << "]";
        }
        if (wstatus == -1 && errno == ECHILD) {
            LOG(WARNING) << "errno is ECHILD";
        }
    }};
#endif
}
// Forks and execs `real_cmd` under /bin/sh -c with the child's stdout wired
// to pipein_fds[1] (parent reads pipein_fds[0]) and stdin wired to
// pipeout_fds[0] (parent writes pipeout_fds[1]).
// Returns the child pid in the parent, -1 on failure; never returns in the
// child (exec or exit(127)).
static int shell_p2open_fork_internal(const char* real_cmd, int pipein_fds[2],
                                      int pipeout_fds[2]) {
#if defined _WIN32 || defined __APPLE__
    return 0;
#else
    int child_pid = -1;
    if ((child_pid = fork()) < 0) {
        return -1;
    }
    if (child_pid != 0) {
        // Parent.
        return child_pid;
    }
    // Child: close the parent's ends of both pipes.
    close(pipein_fds[0]);
    close(pipeout_fds[1]);
    // Redirect stdout to the "in" pipe's write end (unless it already is fd 1).
    if (pipein_fds[1] != 1) {
        if (dup2(pipein_fds[1], 1) != 1) {
            return -1;
        }
        close(pipein_fds[1]);
    }
    // Redirect stdin to the "out" pipe's read end (unless it already is fd 0).
    if (pipeout_fds[0] != 0) {
        if (dup2(pipeout_fds[0], 0) != 0) {
            return -1;
        }
        close(pipeout_fds[0]);
    }
    // Drop every inherited descriptor except the std streams before exec'ing.
    close_open_fds_internal();
    if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
        return -1;
    }
    exit(127);
#endif
}
// Bidirectional popen: runs `cmd` under sh with pipefail and returns a pair
// {read-from-child FILE, write-to-child FILE}.  The child is reaped by a
// shared `child_life` token once BOTH streams have been closed; each stream's
// deleter only fcloses its own FILE and drops its reference to the token.
// Returns {NULL, NULL} on pipe() failure; {} on Windows/macOS.
std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
    const std::string& cmd) {
#if defined _WIN32 || defined __APPLE__
    return {};
#else
    if (shell_verbose()) {
        LOG(INFO) << "Opening bidirectional pipe[" << cmd << "]";
    }
    // pipefail so a failure anywhere in a multi-stage pipeline is reported.
    std::string real_cmd = "set -o pipefail; " + cmd;
    int pipein_fds[2];   // child stdout -> parent
    int pipeout_fds[2];  // parent -> child stdin
    if (pipe(pipein_fds) != 0) {
        return {NULL, NULL};
    }
    if (pipe(pipeout_fds) != 0) {
        // NOTE(review): pipein_fds leak on this path — pre-existing behavior.
        return {NULL, NULL};
    }
    int child_pid =
        shell_p2open_fork_internal(real_cmd.c_str(), pipein_fds, pipeout_fds);
    // Parent keeps pipein_fds[0] (read) and pipeout_fds[1] (write) only.
    close(pipein_fds[1]);
    close(pipeout_fds[0]);
    fcntl(pipein_fds[0], F_SETFD, FD_CLOEXEC);
    fcntl(pipeout_fds[1], F_SETFD, FD_CLOEXEC);
    // Dummy shared_ptr used purely for its deleter: it runs when the last of
    // the two returned streams is destroyed, and reaps the child there.
    std::shared_ptr<int> child_life = {
        NULL, [child_pid, cmd](void*) {
            if (shell_verbose()) {
                LOG(INFO) << "Closing bidirectional pipe[" << cmd << "]";
            }
            int wstatus, ret;
            // Retry waitpid on EINTR.
            do {
                PCHECK((ret = waitpid(child_pid, &wstatus, 0)) >= 0 ||
                       (ret == -1 && errno == EINTR));
            } while (ret == -1 && errno == EINTR);
            // Accept clean exit, death-by-SIGPIPE, or "already reaped".
            PCHECK(wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
                   (wstatus == -1 && errno == ECHILD))
                << "status[" << wstatus << "], cmd[" << cmd << "]";
            if (wstatus == -1 && errno == ECHILD) {
                LOG(WARNING) << "errno is ECHILD";
            }
        }};
    FILE* in_fp;
    PCHECK((in_fp = fdopen(pipein_fds[0], "r")) != NULL);
    FILE* out_fp;
    PCHECK((out_fp = fdopen(pipeout_fds[1], "w")) != NULL);
    // Each deleter captures child_life by value, keeping the reap token alive
    // until both streams are gone.
    return {{in_fp, [child_life](FILE* fp) { PCHECK(fclose(fp) == 0); }},
            {out_fp, [child_life](FILE* fp) { PCHECK(fclose(fp) == 0); }}};
#endif
}
// Runs `cmd` and returns everything it wrote to stdout as a string.
// On pipeline failure (err_no == -1) it sleeps 10s and retries indefinitely;
// returns "" if the command produces no output, or on Windows/macOS.
std::string shell_get_command_output(const std::string& cmd) {
#if defined _WIN32 || defined __APPLE__
    return "";
#else
    int err_no = 0;
    do {
        if (err_no == -1) {
            // Back off before retrying a failed command.
            sleep(10);
        }
        err_no = 0;
        std::shared_ptr<FILE> pipe = shell_popen(cmd, "r", &err_no);
        // Fix: shell_popen returns null on pipe()/fdopen failure; the original
        // passed &*pipe (a null shared_ptr dereference) straight to getdelim.
        // A null pipe leaves err_no == -1, so the retry loop handles it.
        if (pipe != nullptr) {
            string::LineFileReader reader;
            if (reader.getdelim(&*pipe, 0)) {
                // Close the pipe (reaping the child and settling err_no)
                // before deciding whether the output is trustworthy.
                pipe = nullptr;
                if (err_no == 0) {
                    return reader.get();
                }
            }
        }
    } while (err_no == -1);
    return "";
#endif
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/io/shell.h
0 → 100644
浏览文件 @
d7ee6ba1
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fcntl.h>
#include <sys/stat.h>
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/syscall.h>
#endif
#include <sys/types.h>
#ifndef _WIN32
#include <sys/wait.h>
#endif
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Returns a mutable reference to the process-wide shell verbosity flag.
// The function-local static gives thread-safe, lazy initialization (off).
inline bool& shell_verbose_internal() {
    static bool verbose_flag = false;
    return verbose_flag;
}
// Read-only accessor for the shared shell verbosity flag.
inline bool shell_verbose() {
    const bool& flag = shell_verbose_internal();
    return flag;
}
// Sets the shared shell verbosity flag (enables open/close logging).
inline void shell_set_verbose(bool x) {
    bool& flag = shell_verbose_internal();
    flag = x;
}
extern
std
::
shared_ptr
<
FILE
>
shell_fopen
(
const
std
::
string
&
path
,
const
std
::
string
&
mode
);
extern
std
::
shared_ptr
<
FILE
>
shell_popen
(
const
std
::
string
&
cmd
,
const
std
::
string
&
mode
,
int
*
err_no
);
extern
std
::
pair
<
std
::
shared_ptr
<
FILE
>
,
std
::
shared_ptr
<
FILE
>>
shell_p2open
(
const
std
::
string
&
cmd
);
// Executes `cmd` via a write-mode pipe.  The returned stream is discarded
// immediately, so its deleter closes the pipe and waits for the child before
// control returns here; retry as long as that close reports failure (-1).
inline void shell_execute(const std::string& cmd) {
    for (int err_no = -1; err_no == -1;) {
        err_no = 0;
        shell_popen(cmd, "w", &err_no);
    }
}
extern
std
::
string
shell_get_command_output
(
const
std
::
string
&
cmd
);
extern
void
shell_add_read_converter
(
std
::
string
&
path
,
bool
&
is_pipe
,
const
std
::
string
&
converter
);
extern
std
::
shared_ptr
<
FILE
>
shell_open
(
const
std
::
string
&
path
,
bool
is_pipe
,
const
std
::
string
&
mode
,
size_t
buffer_size
,
int
*
err_no
=
0
);
extern
void
shell_add_write_converter
(
std
::
string
&
path
,
bool
&
is_pipe
,
const
std
::
string
&
converter
);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/model/epoch_donefile.txt
0 → 100644
浏览文件 @
d7ee6ba1
20190710 1562775817 afs:/user/feed/mlarch/feed_multiTarget_model/magnet_duration_model_new_label2/batch_model/20190710_18 21 18
20190710 1562779976 afs:/user/feed/mlarch/feed_multiTarget_model/magnet_duration_model_new_label2/batch_model/20190710_18 22 18
20190711 1562783841 afs:/user/feed/mlarch/feed_multiTarget_model/magnet_duration_model_new_label2/batch_model/20190711_0 1565625600 1565625600
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
浏览文件 @
d7ee6ba1
...
...
@@ -5,6 +5,8 @@
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
...
...
@@ -20,26 +22,45 @@ int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
YAML
::
Node
config
=
_context_ptr
->
trainer_config
;
//environment
std
::
string
env_class
=
config
[
"environment"
][
"environment_class"
].
as
<
std
::
string
>
();
auto
*
environment
=
CREATE_CLASS
(
RuntimeEnvironment
,
env_class
);
if
(
environment
->
initialize
(
config
[
"environment"
])
!=
0
)
{
context_ptr
->
environment
.
reset
(
CREATE_CLASS
(
RuntimeEnvironment
,
env_class
));
if
(
context_ptr
->
environment
->
initialize
(
config
[
"environment"
])
!=
0
)
{
return
-
1
;
}
//file_system
context_ptr
->
file_system
.
reset
(
CREATE_CLASS
(
FileSystem
,
"AutoFileSystem"
));
if
(
context_ptr
->
file_system
->
initialize
(
config
[
"io"
],
context_ptr
)
!=
0
)
{
return
-
1
;
}
context_ptr
->
environment
.
reset
(
environment
);
//epoch
std
::
string
epoch_class
=
config
[
"epoch"
][
"epoch_class"
].
as
<
std
::
string
>
();
auto
*
epoch
=
CREATE_CLASS
(
EpochAccessor
,
epoch_class
);
if
(
epoch
->
initialize
(
config
[
"epoch"
],
context_ptr
)
!=
0
)
{
context_ptr
->
epoch_accessor
.
reset
(
CREATE_CLASS
(
EpochAccessor
,
epoch_class
));
if
(
context_ptr
->
epoch_accessor
->
initialize
(
config
[
"epoch"
],
context_ptr
)
!=
0
)
{
return
-
1
;
}
//Dataset
context_ptr
->
dataset
.
reset
(
new
Dataset
());
if
(
context_ptr
->
dataset
->
initialize
(
config
[
"dataset"
],
context_ptr
)
!=
0
)
{
return
-
1
;
}
context_ptr
->
epoch_accessor
.
reset
(
epoch
);
VLOG
(
3
)
<<
"Env initialize success"
;
return
0
;
}
int
InitEnvProcess
::
run
()
{
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
VLOG
(
3
)
<<
"Trainer Resume From epoch:"
<<
epoch_accessor
->
current_epoch_id
();
auto
next_epoch_id
=
epoch_accessor
->
next_epoch_id
(
epoch_accessor
->
current_epoch_id
());
_context_ptr
->
dataset
->
pre_detect_data
(
next_epoch_id
);
//step 1. psserver init
//step2. psserver load
VLOG
(
3
)
<<
"Psserver Start Success"
;
//context_ptr->pslib_client()->load_model();
VLOG
(
3
)
<<
"Psserver Load Model Success"
;
return
0
;
}
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
d7ee6ba1
...
...
@@ -3,6 +3,7 @@
*Train样本
*/
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"
...
...
@@ -49,7 +50,7 @@ std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, Mod
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
)
{
auto
*
environment
=
_context_ptr
->
environment
.
get
();
if
(
!
environment
->
is_master_node
())
{
if
(
!
environment
->
is_master_node
(
EnvironmentRole
::
WORKER
))
{
return
0
;
}
int
ret_size
=
0
;
...
...
@@ -69,16 +70,17 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
}
int
LearnerProcess
::
run
()
{
auto
*
dataset
=
_context_ptr
->
dataset
.
get
();
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
uint64_t
epoch_id
=
epoch_accessor
->
current_epoch_id
();
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
environment
->
log
(
Environment
Role
::
WORKER
,
Environment
LogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Resume traine with epoch_id:%d label:%s"
,
epoch_id
,
_context_ptr
->
epoch_accessor
->
text
(
epoch_id
).
c_str
());
//判断是否先dump出base
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
_all
(
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
while
(
true
)
{
epoch_accessor
->
next_epoch
();
...
...
@@ -87,16 +89,17 @@ int LearnerProcess::run() {
"train epoch_id:%d label:%s"
,
epoch_id
,
epoch_accessor
->
text
(
epoch_id
).
c_str
());
//Step1. 等待样本ready
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
environment
->
log
(
Environment
Role
::
WORKER
,
Environment
LogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Start %s, wait data ready"
,
epoch_log_title
.
c_str
());
while
(
!
epoch_accessor
->
data_ready
(
epoch_id
)
)
{
while
(
dataset
->
epoch_data_status
(
epoch_id
)
!=
DatasetStatus
::
Ready
)
{
sleep
(
30
);
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
dataset
->
pre_detect_data
(
epoch_id
);
environment
->
log
(
EnvironmentRole
::
WORKER
,
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data not ready, wait 30s"
,
epoch_log_title
.
c_str
());
}
environment
->
log
(
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
environment
->
log
(
Environment
Role
::
WORKER
,
Environment
LogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data is ready, start traning"
,
epoch_log_title
.
c_str
());
environment
->
barrier
_all
();
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
//Step2. 运行训练网络
bool
already_dump_inference_model
=
false
;
...
...
@@ -111,13 +114,13 @@ int LearnerProcess::run() {
for
(
int
i
=
0
;
i
<
_train_thread_num
;
++
i
)
{
train_threads
[
i
]
->
join
();
}
environment
->
barrier
_all
();
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
if
(
_threads_executor
[
0
][
i
]
->
is_dump_all_model
())
{
already_dump_inference_model
=
true
;
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
environment
->
barrier
_all
();
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
}
//Step3. Dump Model For Delta&&Checkpoint
...
...
@@ -126,7 +129,7 @@ int LearnerProcess::run() {
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpoint
);
environment
->
barrier
_all
(
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
//Step4. Output Monitor && RunStatus
//TODO
...
...
paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so
0 → 100755
浏览文件 @
d7ee6ba1
文件已添加
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
d7ee6ba1
...
...
@@ -12,6 +12,8 @@ namespace custom_trainer {
namespace
feed
{
class
Process
;
class
Dataset
;
class
FileSystem
;
class
EpochAccessor
;
enum
class
ModelSaveWay
{
...
...
@@ -35,10 +37,13 @@ class TrainerContext {
public:
YAML
::
Node
trainer_config
;
paddle
::
platform
::
CPUPlace
cpu_place
;
std
::
vector
<
TableMeta
>
params_table_list
;
std
::
shared_ptr
<
EpochAccessor
>
epoch_accessor
;
std
::
shared_ptr
<
RuntimeEnvironment
>
environment
;
std
::
vector
<
std
::
shared_ptr
<
Process
>>
process_list
;
std
::
shared_ptr
<
Dataset
>
dataset
;
//训练样本
std
::
shared_ptr
<
FileSystem
>
file_system
;
//文件操作辅助类
std
::
vector
<
TableMeta
>
params_table_list
;
//参数表
std
::
shared_ptr
<
EpochAccessor
>
epoch_accessor
;
//训练轮次控制
std
::
shared_ptr
<
RuntimeEnvironment
>
environment
;
//运行环境
std
::
vector
<
std
::
shared_ptr
<
Process
>>
process_list
;
//训练流程
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/main.cc
浏览文件 @
d7ee6ba1
...
...
@@ -8,5 +8,6 @@ int32_t main(int32_t argc, char** argv) {
::
google
::
InitGoogleLogging
(
argv
[
0
]);
::
testing
::
InitGoogleTest
(
&
argc
,
argv
);
::
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
return
RUN_ALL_TESTS
();
}
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc
0 → 100644
浏览文件 @
d7ee6ba1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <fstream>
#include <gtest/gtest.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
namespace
{
const
char
test_data_dir
[]
=
"test_data"
;
}
// Fixture for DataReader/DataParser tests.  Builds a tiny on-disk dataset
// (test_data/a.txt, test_data/b.txt, two "id data" lines each) once per
// suite, and a fresh LocalFileSystem + TrainerContext per test.  OpenMP is
// pinned to one thread during each test and restored afterwards.
class DataReaderTest : public testing::Test {
public:
    // Suite-level setup: create the fixture directory and sample files.
    static void SetUpTestCase() {
        std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
        fs->mkdir(test_data_dir);
        shell_set_verbose(true);  // log shell open/close while testing
        {
            std::ofstream fout(fs->path_join(test_data_dir, "a.txt"));
            fout << "abc 123456" << std::endl;
            fout << "def 234567" << std::endl;
            fout.close();
        }
        {
            std::ofstream fout(fs->path_join(test_data_dir, "b.txt"));
            fout << "ghi 345678" << std::endl;
            fout << "jkl 456789" << std::endl;
            fout.close();
        }
    }

    // Suite-level teardown: remove the fixture directory.
    static void TearDownTestCase() {
        std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
        fs->remove(test_data_dir);
    }

    // Per-test setup: force single-threaded OpenMP, remembering the ambient
    // width, and create fresh helpers.
    virtual void SetUp() {
        thread_num = omp_get_max_threads();
        omp_set_num_threads(1);
        fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
        context_ptr.reset(new TrainerContext());
    }

    // Per-test teardown: restore the ambient OpenMP thread count.
    virtual void TearDown() {
        omp_set_num_threads(thread_num);
        fs = nullptr;
        context_ptr = nullptr;
    }

    std::shared_ptr<TrainerContext> context_ptr;  // per-test trainer context
    std::unique_ptr<FileSystem> fs;               // per-test local file system
    int thread_num = 1;                           // saved OpenMP max threads
};
// LineDataParser contract: a line is "<id><space><data>"; parsing fails
// (non-zero) when there is no space.  Exercises the std::string, C-string,
// and (C-string, length) overloads of parse(), including a length that cuts
// the data short and one that ends right after the space.
TEST_F(DataReaderTest, LineDataParser) {
    std::unique_ptr<DataParser> data_parser(CREATE_CLASS(DataParser, "LineDataParser"));
    ASSERT_NE(nullptr, data_parser);
    auto config = YAML::Load("");
    ASSERT_EQ(0, data_parser->initialize(config, context_ptr));

    DataItem data_item;
    // std::string overload: no space -> parse error; with space -> id/data split.
    ASSERT_NE(0, data_parser->parse(std::string("1abcd123456"), data_item));
    ASSERT_EQ(0, data_parser->parse(std::string("2abc 123456"), data_item));
    ASSERT_STREQ("2abc", data_item.id.c_str());
    ASSERT_STREQ("123456", data_item.data.c_str());

    // C-string overload behaves the same way.
    ASSERT_NE(0, data_parser->parse("3abcd123456", data_item));
    ASSERT_EQ(0, data_parser->parse("4abc 123456", data_item));
    ASSERT_STREQ("4abc", data_item.id.c_str());
    ASSERT_STREQ("123456", data_item.data.c_str());

    // (buffer, length) overload: length 4 excludes the separator -> error;
    // length 5 ends right after the space -> empty data; length 8 truncates
    // the data to "123".
    ASSERT_NE(0, data_parser->parse("5abc 123456", 4, data_item));
    ASSERT_EQ(0, data_parser->parse("6abc 123456", 5, data_item));
    ASSERT_STREQ("6abc", data_item.id.c_str());
    ASSERT_STREQ("", data_item.data.c_str());
    ASSERT_EQ(0, data_parser->parse("7abc 123456", 8, data_item));
    ASSERT_STREQ("7abc", data_item.id.c_str());
    ASSERT_STREQ("123", data_item.data.c_str());
}
// End-to-end LineDataReader test against the local fixture directory:
// file listing, done-file readiness gating, and reading all four records
// through a channel.
TEST_F(DataReaderTest, LineDataReader) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        "    class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));

    // Both fixture files are listed; sort to make the order deterministic.
    auto data_file_list = data_reader->data_file_list(test_data_dir);
    ASSERT_EQ(2, data_file_list.size());
    std::sort(data_file_list.begin(), data_file_list.end());
    ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
    ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "b.txt"), data_file_list[1]);

    // Data is not "ready" until the configured done_file exists.
    ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
    std::ofstream fout(fs->path_join(test_data_dir, "done_file"));
    fout << "done";
    fout.close();
    ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));

    // Read everything through a channel and check all four records in order.
    auto channel = framework::MakeChannel<DataItem>(128);
    ASSERT_NE(nullptr, channel);
    ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
    framework::ChannelReader<DataItem> reader(channel.get());
    DataItem data_item;

    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("abc", data_item.id.c_str());
    ASSERT_STREQ("123456", data_item.data.c_str());

    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("def", data_item.id.c_str());
    ASSERT_STREQ("234567", data_item.data.c_str());

    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("ghi", data_item.id.c_str());
    ASSERT_STREQ("345678", data_item.data.c_str());

    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("jkl", data_item.id.c_str());
    ASSERT_STREQ("456789", data_item.data.c_str());

    // Channel exhausted: the next read fails.
    reader >> data_item;
    ASSERT_FALSE(reader);
}
// Verifies that `filename_prefix: a` restricts the reader to files whose
// names start with "a" (only a.txt, hence only its two records).
TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        "    class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n"
        "filename_prefix: a");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));

    // b.txt is filtered out by the prefix.
    auto data_file_list = data_reader->data_file_list(test_data_dir);
    ASSERT_EQ(1, data_file_list.size());
    ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);

    auto channel = framework::MakeChannel<DataItem>(128);
    ASSERT_NE(nullptr, channel);
    ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
    framework::ChannelReader<DataItem> reader(channel.get());
    DataItem data_item;

    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("abc", data_item.id.c_str());
    ASSERT_STREQ("123456", data_item.data.c_str());

    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("def", data_item.id.c_str());
    ASSERT_STREQ("234567", data_item.data.c_str());

    // Only a.txt's two records are present.
    reader >> data_item;
    ASSERT_FALSE(reader);
}
// Exercises LineDataReader through an AutoFileSystem that routes afs:/hdfs:
// prefixes to HadoopFileSystem and everything else to the local FS.
// NOTE(review): the second half hits a live AFS cluster and the config embeds
// a UGI credential ('feed_video,D3a0z8') in plain text — this should be moved
// to an external secret/config; confirm whether this test can run in CI at
// all without cluster access.
// NOTE(review): the YAML literal's leading indentation below is reconstructed
// from a whitespace-lossy rendering — verify against the original file.
TEST_F(DataReaderTest, LineDataReader_FileSystem) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        "    class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n"
        "filename_prefix: a\n"
        "file_system:\n"
        "    class: AutoFileSystem\n"
        "    file_systems:\n"
        "        'afs:': &HDFS\n"
        "            class: HadoopFileSystem\n"
        "            hdfs_command: 'hadoop fs'\n"
        "            ugis:\n"
        "                'default': 'feed_video,D3a0z8'\n"
        "                'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'\n"
        "\n"
        "        'hdfs:': *HDFS\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
    {
        // Local path: behaves exactly like the filename_prefix test.
        auto data_file_list = data_reader->data_file_list(test_data_dir);
        ASSERT_EQ(1, data_file_list.size());
        ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);

        auto channel = framework::MakeChannel<DataItem>(128);
        ASSERT_NE(nullptr, channel);
        ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
        framework::ChannelReader<DataItem> reader(channel.get());
        DataItem data_item;

        reader >> data_item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ("abc", data_item.id.c_str());
        ASSERT_STREQ("123456", data_item.data.c_str());

        reader >> data_item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ("def", data_item.id.c_str());
        ASSERT_STREQ("234567", data_item.data.c_str());

        reader >> data_item;
        ASSERT_FALSE(reader);
    }
    {
        // Remote path: expects a pre-provisioned a.txt with the lines
        // "hello world" and "hello hadoop" on the AFS cluster.
        char test_hadoop_dir[] =
            "afs://xingtian.afs.baidu.com:9902/user/feed_video/user/rensilin/paddle_trainer_test_dir";
        ASSERT_TRUE(data_reader->is_data_ready(test_hadoop_dir));
        auto data_file_list = data_reader->data_file_list(test_hadoop_dir);
        ASSERT_EQ(1, data_file_list.size());
        ASSERT_EQ(string::format_string("%s/%s", test_hadoop_dir, "a.txt"), data_file_list[0]);

        auto channel = framework::MakeChannel<DataItem>(128);
        ASSERT_NE(nullptr, channel);
        ASSERT_EQ(0, data_reader->read_all(test_hadoop_dir, channel));
        framework::ChannelReader<DataItem> reader(channel.get());
        DataItem data_item;

        reader >> data_item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ("hello", data_item.id.c_str());
        ASSERT_STREQ("world", data_item.data.c_str());

        reader >> data_item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ("hello", data_item.id.c_str());
        ASSERT_STREQ("hadoop", data_item.data.c_str());

        reader >> data_item;
        ASSERT_FALSE(reader);
    }
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc
0 → 100644
浏览文件 @
d7ee6ba1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <fstream>
#include <algorithm>
#include <gtest/gtest.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
namespace
{
const
char
test_data_dir
[]
=
"test_data"
;
}
// Fixture for the OpenMP DataReader tests.  Creates 26 one-record files
// (a.txt .. z.txt, each "id data").  `sorted_std_items` holds the records in
// alphabetical (creation) order; `std_items` holds them in the file system's
// listing order, which is the order a single-threaded read should produce.
class DataReaderOmpTest : public testing::Test {
public:
    // Suite-level setup: build the fixture files and both reference vectors.
    static void SetUpTestCase() {
        std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
        fs->mkdir(test_data_dir);
        shell_set_verbose(true);  // log shell open/close while testing
        std_items.clear();
        sorted_std_items.clear();
        for (char c = 'a'; c <= 'z'; ++c) {
            DataItem item;
            item.id = c;
            item.data = std::to_string(c - 'a');  // "a"->"0", "b"->"1", ...
            std::ofstream fout(fs->path_join(
                test_data_dir, string::format_string("%c.txt", c)));
            fout << item.id << " " << item.data << std::endl;
            fout.close();
            sorted_std_items.push_back(std::move(item));
        }
        // Re-read the records in directory-listing order as the single-thread
        // reference ordering.
        for (const auto& filename : fs->list(test_data_dir)) {
            std::ifstream fin(filename);
            DataItem item;
            fin >> item.id >> item.data;
            fin.close();
            std_items.push_back(std::move(item));
        }
    }

    // Suite-level teardown: remove the fixture directory.
    static void TearDownTestCase() {
        std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
        fs->remove(test_data_dir);
    }

    // Per-test setup: force single-threaded OpenMP (tests widen it locally).
    virtual void SetUp() {
        thread_num = omp_get_max_threads();
        omp_set_num_threads(1);
        fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
        context_ptr.reset(new TrainerContext());
    }

    // Per-test teardown: restore the ambient OpenMP thread count.
    virtual void TearDown() {
        omp_set_num_threads(thread_num);
        fs = nullptr;
        context_ptr = nullptr;
    }

    // Element-wise (id, data) equality of two record vectors.
    static bool is_same(const std::vector<DataItem>& a,
                        const std::vector<DataItem>& b) {
        int a_size = a.size();
        if (a_size != b.size()) {
            return false;
        }
        for (int i = 0; i < a_size; ++i) {
            if (a[i].id != b[i].id || a[i].data != b[i].data) {
                return false;
            }
        }
        return true;
    }

    // Compare against the listing-order reference.
    static bool is_same_with_std_items(const std::vector<DataItem>& items) {
        return is_same(items, std_items);
    }

    // Compare against the alphabetical reference.
    static bool is_same_with_sorted_std_items(const std::vector<DataItem>& items) {
        return is_same(items, sorted_std_items);
    }

    static std::vector<DataItem> std_items;         // listing-order reference
    static std::vector<DataItem> sorted_std_items;  // alphabetical reference
    std::shared_ptr<TrainerContext> context_ptr;    // per-test trainer context
    std::unique_ptr<FileSystem> fs;                 // per-test local file system
    int thread_num = 1;                             // saved OpenMP max threads
    const int n_run = 5;                            // repetitions per test
};
std::vector<DataItem> DataReaderOmpTest::std_items;
std::vector<DataItem> DataReaderOmpTest::sorted_std_items;
// With one OpenMP thread, read_all must produce the records in the
// deterministic file-listing order on every run.
TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        "    class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));

    // The file list matches the listing-order reference one-to-one.
    auto data_file_list = data_reader->data_file_list(test_data_dir);
    const int std_items_size = std_items.size();
    ASSERT_EQ(std_items_size, data_file_list.size());
    for (int i = 0; i < std_items_size; ++i) {
        ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir,
                                        std_items[i].id.c_str()),
                  data_file_list[i]);
    }

    int same_count = 0;
    for (int i = 0; i < n_run; ++i) {
        auto channel = framework::MakeChannel<DataItem>(128);
        ASSERT_NE(nullptr, channel);
        ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
        std::vector<DataItem> items;
        channel->ReadAll(items);
        if (is_same_with_std_items(items)) {
            ++same_count;
        }
    }
    // Identical ordering in all n_run runs (single thread is deterministic).
    ASSERT_EQ(n_run, same_count);
}
// With four OpenMP threads, read_all should interleave file order across runs
// (proving parallelism) while still delivering exactly the same set of
// records, verified by comparing after sorting by id.
TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        "    class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));

    auto data_file_list = data_reader->data_file_list(test_data_dir);
    const int std_items_size = std_items.size();
    ASSERT_EQ(std_items_size, data_file_list.size());
    for (int i = 0; i < std_items_size; ++i) {
        ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir,
                                        std_items[i].id.c_str()),
                  data_file_list[i]);
    }

    // Readiness is gated on the done_file, as in the single-thread suite.
    ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
    std::ofstream fout(fs->path_join(test_data_dir, "done_file"));
    fout << "done";
    fout.close();
    ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));

    int same_count = 0;
    int sort_same_count = 0;
    for (int i = 0; i < n_run; ++i) {
        auto channel = framework::MakeChannel<DataItem>(128);
        ASSERT_NE(nullptr, channel);
        omp_set_num_threads(4);  // widen OpenMP for this read
        ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
        std::vector<DataItem> items;
        channel->ReadAll(items);
        if (is_same_with_std_items(items)) {
            ++same_count;
        }
        std::sort(items.begin(), items.end(),
                  [](const DataItem& a, const DataItem& b) {
                      return a.id < b.id;
                  });
        if (is_same_with_sorted_std_items(items)) {
            ++sort_same_count;
        }
    }
    ASSERT_EQ(4, omp_get_max_threads());
    // At least one of the n_run runs differs in order (evidence of
    // multi-threaded interleaving)...
    ASSERT_GT(n_run, same_count);
    // ...but after sorting, every run contains exactly the same records.
    ASSERT_EQ(n_run, sort_same_count);
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
浏览文件 @
d7ee6ba1
...
...
@@ -13,66 +13,118 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <fstream>
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
TEST
(
testSimpleExecutor
,
initialize
)
{
SimpleExecutor
execute
;
auto
context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
YAML
::
Node
config
=
YAML
::
Load
(
"[1, 2, 3]"
);
ASSERT_NE
(
0
,
execute
.
initialize
(
config
,
context_ptr
));
config
=
YAML
::
Load
(
"{startup_program: ./data/startup_program, main_program: ./data/main_program}"
);
ASSERT_EQ
(
0
,
execute
.
initialize
(
config
,
context_ptr
));
config
=
YAML
::
Load
(
"{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}"
);
ASSERT_EQ
(
0
,
execute
.
initialize
(
config
,
context_ptr
));
namespace
{
const
char
test_data_dir
[]
=
"test_data"
;
const
char
main_program_path
[]
=
"test_data/main_program"
;
const
char
startup_program_path
[]
=
"test_data/startup_program"
;
}
// Draws a pseudo-random float uniformly distributed over [min, max].
// Uses rand() as the underlying generator, so the sequence depends on the
// process-wide seed (default seed if srand() was never called).
float uniform(float min, float max) {
    const float unit = static_cast<float>(rand()) / RAND_MAX;  // in [0, 1]
    return min + unit * (max - min);
}
class
SimpleExecutorTest
:
public
testing
::
Test
{
public:
static
void
SetUpTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
fs
->
mkdir
(
test_data_dir
);
shell_set_verbose
(
true
);
{
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
startup_program
(
new
paddle
::
framework
::
ProgramDesc
());
std
::
ofstream
fout
(
startup_program_path
,
std
::
ios
::
out
|
std
::
ios
::
binary
);
ASSERT_TRUE
(
fout
);
fout
<<
startup_program
->
Proto
()
->
SerializeAsString
();
fout
.
close
();
}
{
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
main_program
(
new
paddle
::
framework
::
ProgramDesc
());
auto
load_block
=
main_program
->
MutableBlock
(
0
);
framework
::
OpDesc
*
op
=
load_block
->
AppendOp
();
op
->
SetType
(
"mean"
);
op
->
SetInput
(
"X"
,
{
"x"
});
op
->
SetOutput
(
"Out"
,
{
"mean"
});
op
->
CheckAttrs
();
std
::
ofstream
fout
(
main_program_path
,
std
::
ios
::
out
|
std
::
ios
::
binary
);
ASSERT_TRUE
(
fout
);
fout
<<
main_program
->
Proto
()
->
SerializeAsString
();
fout
.
close
();
}
}
void
next_batch
(
int
batch_size
,
const
paddle
::
platform
::
Place
&
place
,
paddle
::
framework
::
LoDTensor
*
x_tensor
,
paddle
::
framework
::
LoDTensor
*
y_tensor
)
{
x_tensor
->
Resize
({
batch_size
,
2
});
auto
x_data
=
x_tensor
->
mutable_data
<
float
>
(
place
);
static
void
TearDownTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
fs
->
remove
(
test_data_dir
);
}
y_tensor
->
Resize
({
batch_size
,
1
});
auto
y_data
=
y_tensor
->
mutable_data
<
float
>
(
place
);
virtual
void
SetUp
()
{
context_ptr
.
reset
(
new
TrainerContext
());
}
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
x_data
[
i
*
2
]
=
uniform
(
-
2
,
2
);
x_data
[
i
*
2
+
1
]
=
uniform
(
-
2
,
2
);
float
dis
=
x_data
[
i
*
2
]
*
x_data
[
i
*
2
]
+
x_data
[
i
*
2
+
1
]
*
x_data
[
i
*
2
+
1
];
y_data
[
i
]
=
dis
<
1.0
?
1.0
:
0.0
;
virtual
void
TearDown
()
{
context_ptr
=
nullptr
;
}
std
::
shared_ptr
<
TrainerContext
>
context_ptr
;
};
// Verifies SimpleExecutor::initialize(): a malformed (sequence) config must
// be rejected, while well-formed mapping configs — with or without an
// explicit thread_num — must be accepted.
TEST_F(SimpleExecutorTest, initialize) {
    std::unique_ptr<Executor> exec(CREATE_CLASS(Executor, "SimpleExecutor"));
    ASSERT_NE(nullptr, exec);

    // A YAML sequence is not a valid executor config; initialize must fail.
    YAML::Node bad_config = YAML::Load("[1, 2, 3]");
    ASSERT_NE(0, exec->initialize(bad_config, context_ptr));

    // Minimal valid config: just the two program paths.
    YAML::Node config = YAML::Load(string::format_string(
        "{startup_program: %s, main_program: %s}",
        startup_program_path, main_program_path));
    ASSERT_EQ(0, exec->initialize(config, context_ptr));

    // Also valid with an explicit thread_num.
    config = YAML::Load(string::format_string(
        "{thread_num: 2, startup_program: %s, main_program: %s}",
        startup_program_path, main_program_path));
    ASSERT_EQ(0, exec->initialize(config, context_ptr));
}
TEST
(
testSimpleExecutor
,
run
)
{
SimpleExecutor
execute
;
auto
context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
auto
config
=
YAML
::
Load
(
"{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}"
);
ASSERT_EQ
(
0
,
execute
.
initialize
(
config
,
context_ptr
));
TEST_F
(
SimpleExecutorTest
,
run
)
{
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_CLASS
(
Executor
,
"SimpleExecutor"
));
ASSERT_NE
(
nullptr
,
executor
);
auto
config
=
YAML
::
Load
(
string
::
format_string
(
"{thread_num: 2, startup_program: %s, main_program: %s}"
,
startup_program_path
,
main_program_path
));
ASSERT_EQ
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
auto
x_var
=
execut
e
.
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"x"
);
auto
y_var
=
execute
.
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"y
"
);
auto
x_var
=
execut
or
->
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"x"
);
executor
->
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean
"
);
ASSERT_NE
(
nullptr
,
x_var
);
ASSERT_NE
(
nullptr
,
y_var
);
next_batch
(
1024
,
context_ptr
->
cpu_place
,
x_var
,
y_var
);
int
x_len
=
10
;
x_var
->
Resize
({
1
,
x_len
});
auto
x_data
=
x_var
->
mutable_data
<
float
>
(
context_ptr
->
cpu_place
);
std
::
cout
<<
"x: "
;
for
(
int
i
=
0
;
i
<
x_len
;
++
i
)
{
x_data
[
i
]
=
i
;
std
::
cout
<<
i
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
ASSERT_EQ
(
0
,
execut
e
.
run
());
ASSERT_EQ
(
0
,
execut
or
->
run
());
auto
loss_var
=
execute
.
var
<::
paddle
::
framework
::
LoDTensor
>
(
"loss"
);
auto
loss
=
loss_var
.
data
<
float
>
()[
0
];
std
::
cout
<<
"loss: "
<<
loss
<<
std
::
endl
;
auto
mean_var
=
executor
->
var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
auto
mean
=
mean_var
.
data
<
float
>
()[
0
];
std
::
cout
<<
"mean: "
<<
mean
<<
std
::
endl
;
ASSERT_NEAR
(
4.5
,
mean
,
1e-9
);
}
}
// namespace feed
...
...
release.bcloud
浏览文件 @
d7ee6ba1
...
...
@@ -3,3 +3,9 @@ mkdir -p so
cp
baidu_third-party_mklml/so/
*
so
rm
-rf
baidu_third-party_mklml
cp
baidu_third-party_openmpi/so/
*
so
rm
-rf
baidu_third-party_openmpi
rm
lib/libfake_paddle_proto.a
rmdir
lib 2>/dev/null
||
:
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录