Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
f6bd8cfd
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f6bd8cfd
编写于
8月 16, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
for mpi trainer
上级
d7ee6ba1
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
391 addition
and
129 deletion
+391
-129
BCLOUD
BCLOUD
+1
-0
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
...fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+1
-1
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
.../train/custom_trainer/feed/accessor/input_data_accessor.h
+31
-0
paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc
...ain/custom_trainer/feed/accessor/sparse_input_accessor.cc
+51
-0
paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
...e/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
+80
-0
paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h
...le/fluid/train/custom_trainer/feed/common/pslib_warpper.h
+49
-0
paddle/fluid/train/custom_trainer/feed/common/registerer.cc
paddle/fluid/train/custom_trainer/feed/common/registerer.cc
+3
-3
paddle/fluid/train/custom_trainer/feed/common/registerer.h
paddle/fluid/train/custom_trainer/feed/common/registerer.h
+9
-9
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
...d/train/custom_trainer/feed/common/runtime_environment.cc
+11
-2
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
...id/train/custom_trainer/feed/common/runtime_environment.h
+19
-14
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
...le/fluid/train/custom_trainer/feed/dataset/data_reader.cc
+5
-5
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+2
-2
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
...id/train/custom_trainer/feed/dataset/dataset_container.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/executor/executor.h
paddle/fluid/train/custom_trainer/feed/executor/executor.h
+1
-1
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc
...le/fluid/train/custom_trainer/feed/io/auto_file_system.cc
+3
-3
paddle/fluid/train/custom_trainer/feed/io/file_system.h
paddle/fluid/train/custom_trainer/feed/io/file_system.h
+1
-1
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
.../fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc
...e/fluid/train/custom_trainer/feed/io/local_file_system.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/main.cc
paddle/fluid/train/custom_trainer/feed/main.cc
+45
-16
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
+1
-1
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
...uid/train/custom_trainer/feed/process/init_env_process.cc
+8
-12
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+41
-35
paddle/fluid/train/custom_trainer/feed/process/process.cc
paddle/fluid/train/custom_trainer/feed/process/process.cc
+2
-2
paddle/fluid/train/custom_trainer/feed/process/process.h
paddle/fluid/train/custom_trainer/feed/process/process.h
+1
-1
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+2
-0
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc
...id/train/custom_trainer/feed/unit_test/test_datareader.cc
+8
-7
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc
...rain/custom_trainer/feed/unit_test/test_datareader_omp.cc
+6
-5
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
...luid/train/custom_trainer/feed/unit_test/test_executor.cc
+5
-4
未找到文件。
BCLOUD
浏览文件 @
f6bd8cfd
...
...
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/paddlepaddle/pslib@master@git_branch')
CONFIGS('third-64/gtest@base')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
浏览文件 @
f6bd8cfd
...
...
@@ -88,7 +88,7 @@ namespace feed {
return
""
;
}
REGIST
ER
_CLASS
(
EpochAccessor
,
HourlyEpochAccessor
);
REGIST_CLASS
(
EpochAccessor
,
HourlyEpochAccessor
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
浏览文件 @
f6bd8cfd
...
...
@@ -60,7 +60,7 @@ protected:
std
::
vector
<
std
::
string
>
_done_status
;
//当前完成状态,统一存成string
};
REGIST
ER
_REGISTERER
(
EpochAccessor
);
REGIST_REGISTERER
(
EpochAccessor
);
class
HourlyEpochAccessor
:
public
EpochAccessor
{
public:
...
...
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
0 → 100644
浏览文件 @
f6bd8cfd
#pragma once
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Abstract accessor interface with a configuration hook plus a forward and a
// backward step that operate on a batch of samples and a framework Scope.
// Concrete implementations are looked up by name through the registerer
// generated by REGIST_REGISTERER below.
class DataInputAccessor : public Accessor {
public:
    DataInputAccessor() {}
    virtual ~DataInputAccessor() {}

    // Configures the accessor from its YAML node.
    // Declared pure virtual: this class provides no definition of its own, and
    // leaving it as a plain (non-pure, undefined) virtual leaves the class
    // vtable without a key-function definition at link time.
    virtual int initialize(const YAML::Node& config,
        std::shared_ptr<TrainerContext> context_ptr) = 0;

    // Forward step: generally used to fill inputs; called before the training
    // network is executed.
    virtual int32_t forward(const SampleInstance* samples,
        ::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0;

    // Backward step: generally used to update gradients; called after the
    // training network has been executed.
    virtual int32_t backward(const SampleInstance* samples,
        ::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0;

protected:
};
REGIST_REGISTERER(DataInputAccessor);
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc
0 → 100644
浏览文件 @
f6bd8cfd
#include <vector>
#include <utility>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
// Sparse-input accessor. initialize() reads the "sparse_input" map from the
// configuration and records, for every entry, its name together with the list
// of slot ids parsed from the entry's comma-separated "slots" string.
// forward()/backward() are placeholders that currently do nothing and return 0.
class CommonSparseInputAccessor : public DataInputAccessor {
public:
    CommonSparseInputAccessor() {}
    virtual ~CommonSparseInputAccessor() {}

    // Parses config["sparse_input"] (must exist and be a YAML map) into
    // _x_variables. Returns 0.
    virtual int initialize(const YAML::Node& config,
        std::shared_ptr<TrainerContext> context_ptr) {
        CHECK(config["sparse_input"] && config["sparse_input"].Type() == YAML::NodeType::Map);
        for (const auto& input : config["sparse_input"]) {
            std::pair<std::string, std::vector<uint16_t>> sparse_slots;
            sparse_slots.first = input.first.as<std::string>();
            std::string slots_str = input.second["slots"].as<std::string>();
            std::vector<std::string> slots = paddle::string::split_string(slots_str, ",");
            sparse_slots.second.reserve(slots.size());
            for (const auto& slot : slots) {
                sparse_slots.second.push_back(static_cast<uint16_t>(atoi(slot.c_str())));
            }
            // Keep the parsed entry; previously it was built and then dropped
            // at the end of each iteration, leaving _x_variables empty.
            _x_variables.push_back(std::move(sparse_slots));
        }
        return 0;
    }

    // Fetch sparse data. Placeholder: the pull is not implemented yet.
    virtual int32_t forward(const SampleInstance* samples,
        ::paddle::framework::Scope* scope, size_t table_id, size_t num) {
        // pull
        return 0;
    }

    // Update sparse data. Placeholder: nothing is pushed yet.
    virtual int32_t backward(const SampleInstance* samples,
        ::paddle::framework::Scope* scope, size_t table_id, size_t num) {
        return 0;
    }

protected:
    // Input layer list: <data_name, slot_id_list>
    std::vector<std::pair<std::string, std::vector<uint16_t> > > _x_variables;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc
0 → 100644
浏览文件 @
f6bd8cfd
#include <fcntl.h>
#include <fstream>
#include <sstream>
#include "json2pb/json_to_pb.h"
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
int
PSlib
::
initialize
(
const
std
::
string
&
conf_path
,
RuntimeEnvironment
*
environment
,
EnvironmentRole
role
)
{
init_gflag
();
int
file_descriptor
=
open
(
conf_path
.
c_str
(),
O_RDONLY
);
if
(
file_descriptor
==
-
1
){
LOG
(
ERROR
)
<<
"FATAL: cant open "
<<
conf_path
;
return
-
1
;
}
google
::
protobuf
::
io
::
FileInputStream
fileInput
(
file_descriptor
);
if
(
!
google
::
protobuf
::
TextFormat
::
Parse
(
&
fileInput
,
&
_ps_param
))
{
LOG
(
ERROR
)
<<
"FATAL: fail to parse "
<<
conf_path
;
return
-
1
;
}
close
(
file_descriptor
);
init_server
(
role
);
init_client
(
EnvironmentRole
::
ALL
);
return
0
;
}
// Server-side bring-up. When this node's role is PSERVER a server is created
// from _ps_param, configured with the PS environment and this node's rank for
// that role, and started. Regardless of role, gather_ps_servers() is then
// invoked on the PS environment. Always returns 0.
int PSlib::init_server(EnvironmentRole role) {
    const bool runs_server = (role == EnvironmentRole::PSERVER);
    if (runs_server) {
        _server_ptr.reset(paddle::ps::PSServerFactory::create(_ps_param));
        _server_ptr->configure(
            _ps_param,
            *(_environment->ps_environment()),
            _environment->rank_id(role));
        _server_ptr->start();
    }

    _environment->ps_environment()->gather_ps_servers();
    return 0;
}
// Client-side bring-up: creates a client from _ps_param and configures it with
// the PS environment and this node's rank for the given role. Always returns 0.
int PSlib::init_client(EnvironmentRole role) {
    _client_ptr.reset(paddle::ps::PSClientFactory::create(_ps_param));
    _client_ptr->configure(
        _ps_param,
        *(_environment->ps_environment()),
        _environment->rank_id(role));

    return 0;
}
// Non-owning pointer to the server held in _server_ptr (null when no server
// has been created).
paddle::ps::PSServer* PSlib::ps_server() {
    paddle::ps::PSServer* server = _server_ptr.get();
    return server;
}
// Non-owning pointer to the client held in _client_ptr (null when no client
// has been created).
paddle::ps::PSClient* PSlib::ps_client() {
    paddle::ps::PSClient* client = _client_ptr.get();
    return client;
}
// Address of the PSParameter member that initialize() parses the
// configuration into. The object is owned by this PSlib instance.
paddle::PSParameter* PSlib::get_param() {
    paddle::PSParameter* param = &_ps_param;
    return param;
}
// Feeds a fixed set of flags to gflags by building a synthetic argv and
// handing it to ParseCommandLineFlags.
void PSlib::init_gflag() {
    // Writable buffers: ParseCommandLineFlags takes a non-const char***.
    char p0[] = "exe default";
    char p1[] = "-max_body_size=314217728";
    char p2[] = "-bthread_concurrency=40";
    char p3[] = "-socket_max_unwritten_bytes=2048000000";

    // The argv array lives on the stack. The previous
    // std::shared_ptr<char*>(new char*[cnt]) released an array allocation with
    // plain delete, which is undefined behaviour.
    char* params[] = {p0, p1, p2, p3};
    int cnt = static_cast<int>(sizeof(params) / sizeof(params[0]));
    char** params_ptr = params;
    ::google::ParseCommandLineFlags(&cnt, &params_ptr, true);
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h
0 → 100644
浏览文件 @
f6bd8cfd
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "communicate/ps_server.h"
#include "communicate/ps_client.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
class
RuntimeEnvironment
;
enum
class
EnvironmentRole
;
// Wrapper that owns a parameter-server configuration (_ps_param) together with
// the server and client objects created from it.
class PSlib {
public:
    PSlib() {}
    virtual ~PSlib() {}

    // Parses the configuration file at conf_path and initializes the server
    // (depending on role) and the client. Returns 0 on success, -1 on failure.
    int initialize(const std::string& conf_path,
        RuntimeEnvironment* environment, EnvironmentRole role);

    // Non-owning accessors; the pointed-to objects are owned by this instance.
    virtual paddle::ps::PSServer* ps_server();
    virtual paddle::ps::PSClient* ps_client();
    virtual paddle::PSParameter* get_param();

private:
    // Passes a fixed set of command-line style flags to gflags.
    void init_gflag();
    virtual int init_server(EnvironmentRole role);
    virtual int init_client(EnvironmentRole role);

    paddle::PSParameter _ps_param;
    // Not owned. Defaults to null so the pointer is never indeterminate before
    // initialize() has been called.
    RuntimeEnvironment* _environment = nullptr;
    std::shared_ptr<paddle::ps::PSServer> _server_ptr;
    std::shared_ptr<paddle::ps::PSClient> _client_ptr;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/common/registerer.cc
浏览文件 @
f6bd8cfd
...
...
@@ -3,12 +3,12 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
BaseClassMap
&
global_factory_map
()
{
BaseClassMap
&
global_
reg_
factory_map
()
{
static
BaseClassMap
*
base_class
=
new
BaseClassMap
();
return
*
base_class
;
}
BaseClassMap
&
global_factory_map_cpp
()
{
return
global_factory_map
();
BaseClassMap
&
global_
reg_
factory_map_cpp
()
{
return
global_
reg_
factory_map
();
}
}
// feed
...
...
paddle/fluid/train/custom_trainer/feed/common/registerer.h
浏览文件 @
f6bd8cfd
...
...
@@ -63,23 +63,23 @@ typedef std::map<std::string, FactoryMap> BaseClassMap;
#ifdef __cplusplus
extern
"C"
{
#endif
BaseClassMap
&
global_factory_map
();
BaseClassMap
&
global_
reg_
factory_map
();
#ifdef __cplusplus
}
#endif
BaseClassMap
&
global_factory_map_cpp
();
BaseClassMap
&
global_
reg_
factory_map_cpp
();
#define REGIST
ER
_REGISTERER(base_class) \
#define REGIST_REGISTERER(base_class) \
class base_class ## Registerer { \
public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \
if (global_factory_map_cpp().find(#base_class) \
== global_factory_map_cpp().end()) { \
if (global_
reg_
factory_map_cpp().find(#base_class) \
== global_
reg_
factory_map_cpp().end()) { \
LOG(ERROR) << "Can't Find BaseClass For CreateClass with:" << #base_class; \
return NULL; \
} \
FactoryMap &map = global_factory_map_cpp()[#base_class]; \
FactoryMap &map = global_
reg_
factory_map_cpp()[#base_class]; \
FactoryMap::iterator iter = map.find(name); \
if (iter == map.end()) { \
LOG(ERROR) << "Can't Find Class For Create with:" << name; \
...
...
@@ -90,7 +90,7 @@ BaseClassMap& global_factory_map_cpp();
} \
};
#define REGIST
ER
_CLASS(clazz, name) \
#define REGIST_CLASS(clazz, name) \
class ObjectFactory##name : public ObjectFactory { \
public: \
Any NewInstance() { \
...
...
@@ -98,14 +98,14 @@ BaseClassMap& global_factory_map_cpp();
} \
}; \
void register_factory_##name() { \
FactoryMap &map = global_factory_map_cpp()[#clazz]; \
FactoryMap &map = global_
reg_
factory_map_cpp()[#clazz]; \
if (map.find(#name) == map.end()) { \
map[#name] = new ObjectFactory##name(); \
} \
} \
void register_factory_##name() __attribute__((constructor));
#define CREATE_
CLASS
(base_class, name) \
#define CREATE_
INSTANCE
(base_class, name) \
base_class##Registerer::CreateInstanceByName(name)
}
//namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
浏览文件 @
f6bd8cfd
...
...
@@ -44,6 +44,11 @@ public:
set_role
(
EnvironmentRole
::
ALL
);
return
0
;
}
// Returns the address of a function-local static MpiPSEnvironment, so every
// call yields the same object.
virtual paddle::ps::PSEnvironment* ps_environment() {
    // Named differently from the method to avoid shadowing it.
    static paddle::ps::MpiPSEnvironment mpi_ps_env;
    return &mpi_ps_env;
}
virtual
uint32_t
rank_id
(
EnvironmentRole
role
)
{
return
mpi_node_info
(
role
).
rank_id
;
...
...
@@ -95,7 +100,7 @@ protected:
private:
std
::
vector
<
MpiNodeInfo
>
_roles_node_info
;
};
REGIST
ER
_CLASS
(
RuntimeEnvironment
,
MPIRuntimeEnvironment
);
REGIST_CLASS
(
RuntimeEnvironment
,
MPIRuntimeEnvironment
);
//用于本地模式单机训练
class
LocalRuntimeEnvironment
:
public
RuntimeEnvironment
{
...
...
@@ -108,6 +113,10 @@ public:
virtual
int
wireup
()
{
return
0
;
}
// Returns the address of a function-local static LocalPSEnvironment, so every
// call yields the same object.
virtual paddle::ps::PSEnvironment* ps_environment() {
    // Named differently from the method to avoid shadowing it.
    static paddle::ps::LocalPSEnvironment local_ps_env;
    return &local_ps_env;
}
virtual
uint32_t
rank_id
(
EnvironmentRole
role
)
{
return
0
;
}
...
...
@@ -129,7 +138,7 @@ protected:
VLOG
(
static_cast
<
int
>
(
level
))
<<
log_str
;
}
};
REGIST
ER
_CLASS
(
RuntimeEnvironment
,
LocalRuntimeEnvironment
);
REGIST_CLASS
(
RuntimeEnvironment
,
LocalRuntimeEnvironment
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
浏览文件 @
f6bd8cfd
...
...
@@ -6,6 +6,7 @@
*/
#pragma once
#include <yaml-cpp/yaml.h>
#include "communicate/ps_env.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
...
...
@@ -14,6 +15,8 @@ namespace paddle {
namespace
custom_trainer
{
namespace
feed
{
// NOTE(review): paddle::ps::PSEnvironment is expected to come from
// "communicate/ps_env.h", which this header includes - confirm. A standalone
// `class paddle::ps::PSEnvironment;` is not a valid forward declaration (a
// qualified name cannot be forward-declared this way), so none is repeated here.
enum
class
EnvironmentLogLevel
{
FATAL
=
0
,
ERROR
=
1
,
...
...
@@ -38,41 +41,43 @@ class RuntimeEnvironment {
public:
RuntimeEnvironment
();
virtual
~
RuntimeEnvironment
();
//配置初始化
//
配置初始化
virtual
int
initialize
(
YAML
::
Node
config
)
=
0
;
//设置role
//
设置role
virtual
int
set_role
(
EnvironmentRole
role
)
=
0
;
//环境初始化,会在所有依赖模块initialize后调用
//
环境初始化,会在所有依赖模块initialize后调用
virtual
int
wireup
()
=
0
;
//多线程可调用接口 Start
//当前环境rank_idx
//
多线程可调用接口 Start
//
当前环境rank_idx
virtual
uint32_t
rank_id
(
EnvironmentRole
role
)
=
0
;
//运行环境节点数
//
运行环境节点数
virtual
uint32_t
node_num
(
EnvironmentRole
role
)
=
0
;
//环境内主节点
//
环境内主节点
virtual
bool
is_master_node
(
EnvironmentRole
role
);
//For PS
virtual
paddle
::
ps
::
PSEnvironment
*
ps_environment
()
=
0
;
//环境定制化log
//
环境定制化log
template
<
class
...
ARGS
>
void
log
(
EnvironmentRole
role
,
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
char
*
fmt
,
ARGS
&&
...
args
)
{
print_log
(
role
,
type
,
level
,
paddle
::
string
::
format_string
(
fmt
,
args
...));
}
//多线程可调用接口 End
//
多线程可调用接口 End
//接口只允许在主线程调用 Start
//barrier 指定role的节点
//
接口只允许在主线程调用 Start
//
barrier 指定role的节点
virtual
void
barrier
(
EnvironmentRole
role
)
=
0
;
//bcast 广播
//
bcast 广播
virtual
void
bcast
(
paddle
::
framework
::
BinaryArchive
&
ar
,
int
root_id
,
EnvironmentRole
role
)
=
0
;
//接口只允许在主线程调用 End
//
接口只允许在主线程调用 End
protected:
virtual
void
print_log
(
EnvironmentRole
role
,
EnvironmentLogType
type
,
EnvironmentLogLevel
level
,
const
std
::
string
&
log_str
)
=
0
;
};
REGIST
ER
_REGISTERER
(
RuntimeEnvironment
);
REGIST_REGISTERER
(
RuntimeEnvironment
);
std
::
string
format_timestamp
(
time_t
time
,
const
char
*
format
);
inline
std
::
string
format_timestamp
(
time_t
time
,
const
std
::
string
&
format
)
{
...
...
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
浏览文件 @
f6bd8cfd
...
...
@@ -56,10 +56,10 @@ public:
return
0
;
}
};
REGIST
ER
_CLASS
(
DataParser
,
LineDataParser
);
REGIST_CLASS
(
DataParser
,
LineDataParser
);
int
DataReader
::
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
_parser
.
reset
(
CREATE_
CLASS
(
DataParser
,
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
()));
_parser
.
reset
(
CREATE_
INSTANCE
(
DataParser
,
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
()));
if
(
_parser
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to get parser: "
<<
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
();
return
-
1
;
...
...
@@ -85,7 +85,7 @@ public:
if
(
config
[
"file_system"
]
&&
config
[
"file_system"
][
"class"
])
{
_file_system
.
reset
(
CREATE_
CLASS
(
FileSystem
,
config
[
"file_system"
][
"class"
].
as
<
std
::
string
>
()));
CREATE_
INSTANCE
(
FileSystem
,
config
[
"file_system"
][
"class"
].
as
<
std
::
string
>
()));
if
(
_file_system
==
nullptr
||
_file_system
->
initialize
(
config
[
"file_system"
],
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to create class: "
...
...
@@ -95,7 +95,7 @@ public:
}
else
if
(
context
->
file_system
!=
nullptr
)
{
_file_system
=
context
->
file_system
;
}
else
{
_file_system
.
reset
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
_file_system
.
reset
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
if
(
_file_system
==
nullptr
||
_file_system
->
initialize
(
YAML
::
Load
(
""
),
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to init file system"
;
return
-
1
;
...
...
@@ -203,7 +203,7 @@ private:
std
::
string
_filename_prefix
;
std
::
shared_ptr
<
FileSystem
>
_file_system
;
};
REGIST
ER
_CLASS
(
DataReader
,
LineDataReader
);
REGIST_CLASS
(
DataReader
,
LineDataReader
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
浏览文件 @
f6bd8cfd
...
...
@@ -54,7 +54,7 @@ public:
virtual
int
parse
(
const
char
*
str
,
DataItem
&
data
)
const
=
0
;
virtual
int
parse_to_sample
(
const
DataItem
&
data
,
SampleInstance
&
instance
)
const
=
0
;
};
REGIST
ER
_REGISTERER
(
DataParser
);
REGIST_REGISTERER
(
DataParser
);
class
DataReader
{
public:
...
...
@@ -76,7 +76,7 @@ protected:
std
::
shared_ptr
<
DataParser
>
_parser
;
//数据格式转换
std
::
string
_pipeline_cmd
;
//将文件流,重定向到pipeline_cmd,再读入
};
REGIST
ER
_REGISTERER
(
DataReader
);
REGIST_REGISTERER
(
DataReader
);
}
//namespace feed
}
//namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
浏览文件 @
f6bd8cfd
...
...
@@ -32,7 +32,7 @@ int DatasetContainer::initialize(
_data_split_interval
=
config
[
"data_spit_interval"
].
as
<
int
>
();
_data_path_formater
=
config
[
"data_path_formater"
].
as
<
std
::
string
>
();
std
::
string
data_reader_class
=
config
[
"data_reader"
].
as
<
std
::
string
>
();
DataReader
*
data_reader
=
CREATE_
CLASS
(
DataReader
,
data_reader_class
);
DataReader
*
data_reader
=
CREATE_
INSTANCE
(
DataReader
,
data_reader_class
);
_data_reader
.
reset
(
data_reader
);
return
_data_reader
->
initialize
(
config
,
context
);
}
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
浏览文件 @
f6bd8cfd
...
...
@@ -121,7 +121,7 @@ protected:
std
::
unique_ptr
<
Context
>
_context
;
};
REGIST
ER
_CLASS
(
Executor
,
SimpleExecutor
);
REGIST_CLASS
(
Executor
,
SimpleExecutor
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.h
浏览文件 @
f6bd8cfd
...
...
@@ -40,7 +40,7 @@ public:
protected:
::
paddle
::
framework
::
Scope
_scope
;
};
REGIST
ER
_REGISTERER
(
Executor
);
REGIST_REGISTERER
(
Executor
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc
浏览文件 @
f6bd8cfd
...
...
@@ -31,7 +31,7 @@ public:
_file_system
.
clear
();
if
(
config
&&
config
[
"file_systems"
]
&&
config
[
"file_systems"
].
Type
()
==
YAML
::
NodeType
::
Map
)
{
for
(
auto
&
prefix_fs
:
config
[
"file_systems"
])
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
CLASS
(
FileSystem
,
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
)));
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
INSTANCE
(
FileSystem
,
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
)));
if
(
fs
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to create class: "
<<
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
);
return
-
1
;
...
...
@@ -44,7 +44,7 @@ public:
}
}
if
(
_file_system
.
find
(
"default"
)
==
_file_system
.
end
())
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
if
(
fs
==
nullptr
||
fs
->
initialize
(
YAML
::
Load
(
""
),
context
)
!=
0
)
{
return
-
1
;
}
...
...
@@ -122,7 +122,7 @@ public:
private:
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
FileSystem
>>
_file_system
;
};
REGIST
ER
_CLASS
(
FileSystem
,
AutoFileSystem
);
REGIST_CLASS
(
FileSystem
,
AutoFileSystem
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/io/file_system.h
浏览文件 @
f6bd8cfd
...
...
@@ -52,7 +52,7 @@ public:
protected:
int
_err_no
=
0
;
};
REGIST
ER
_REGISTERER
(
FileSystem
);
REGIST_REGISTERER
(
FileSystem
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
浏览文件 @
f6bd8cfd
...
...
@@ -203,7 +203,7 @@ private:
std
::
string
_hdfs_command
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
_ugi
;
};
REGIST
ER
_CLASS
(
FileSystem
,
HadoopFileSystem
);
REGIST_CLASS
(
FileSystem
,
HadoopFileSystem
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc
浏览文件 @
f6bd8cfd
...
...
@@ -129,7 +129,7 @@ public:
private:
size_t
_buffer_size
=
0
;
};
REGIST
ER
_CLASS
(
FileSystem
,
LocalFileSystem
);
REGIST_CLASS
(
FileSystem
,
LocalFileSystem
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/main.cc
浏览文件 @
f6bd8cfd
#include <time.h>
#include <fstream>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -21,27 +21,56 @@ int main(int argc, char* argv[]) {
//load trainer config
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
//environment
auto
&
config
=
trainer_context_ptr
->
trainer_config
;
std
::
string
env_class
=
config
[
"environment"
][
"environment_class"
].
as
<
std
::
string
>
();
trainer_context_ptr
->
environment
.
reset
(
CREATE_INSTANCE
(
RuntimeEnvironment
,
env_class
));
if
(
trainer_context_ptr
->
environment
->
initialize
(
config
[
"environment"
])
!=
0
)
{
return
-
1
;
}
EnvironmentRole
role
;
auto
*
environment
=
trainer_context_ptr
->
environment
.
get
();
environment
->
wireup
();
if
(
environment
->
rank_id
(
EnvironmentRole
::
ALL
)
%
2
==
0
)
{
role
=
EnvironmentRole
::
WORKER
;
}
else
{
role
=
EnvironmentRole
::
PSERVER
;
}
environment
->
set_role
(
role
);
trainer_context_ptr
->
pslib
.
reset
(
new
PSlib
());
std
::
string
ps_config
=
config
[
"environment"
][
"ps"
].
as
<
std
::
string
>
();
trainer_context_ptr
->
pslib
->
initialize
(
ps_config
,
environment
,
role
);
//VLOG(3) << "Node Start With Role:" << role;
std
::
vector
<
std
::
string
>
process_name_list
=
{
"InitEnvProcess"
,
"LearnerProcess"
};
for
(
const
auto
&
process_name
:
process_name_list
)
{
Process
*
process
=
CREATE_CLASS
(
Process
,
process_name
);
if
(
process
==
NULL
)
{
VLOG
(
1
)
<<
"Process:"
<<
process_name
<<
" does not exist"
;
return
-
1
;
switch
(
role
)
{
case
EnvironmentRole
::
WORKER
:
for
(
const
auto
&
process_name
:
process_name_list
)
{
Process
*
process
=
CREATE_INSTANCE
(
Process
,
process_name
);
if
(
process
==
NULL
)
{
VLOG
(
1
)
<<
"Process:"
<<
process_name
<<
" does not exist"
;
return
-
1
;
}
if
(
process
->
initialize
(
trainer_context_ptr
)
!=
0
)
{
VLOG
(
1
)
<<
"Process:"
<<
process_name
<<
" initialize failed"
;
return
-
1
;
}
trainer_context_ptr
->
process_list
.
push_back
(
std
::
shared_ptr
<
Process
>
(
process
));
}
for
(
auto
&
process
:
trainer_context_ptr
->
process_list
)
{
process
->
run
();
}
if
(
process
->
initialize
(
trainer_context_ptr
)
!=
0
)
{
VLOG
(
1
)
<<
"Process:"
<<
process_name
<<
" initialize failed"
;
return
-
1
;
break
;
case
EnvironmentRole
::
PSERVER
:
//wait server done
while
(
true
)
{
sleep
(
10000
);
}
trainer_context_ptr
->
process_list
.
push_back
(
std
::
shared_ptr
<
Process
>
(
process
));
}
for
(
auto
&
process
:
trainer_context_ptr
->
process_list
)
{
process
->
run
();
break
;
}
return
0
;
...
...
paddle/fluid/train/custom_trainer/feed/monitor/monitor.h
浏览文件 @
f6bd8cfd
...
...
@@ -39,7 +39,7 @@ protected:
std
::
string
_name
;
};
REGIST
ER
_REGISTERER
(
Monitor
);
REGIST_REGISTERER
(
Monitor
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
浏览文件 @
f6bd8cfd
...
...
@@ -20,22 +20,16 @@ int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
context_ptr
->
cpu_place
=
paddle
::
platform
::
CPUPlace
();
YAML
::
Node
config
=
_context_ptr
->
trainer_config
;
//environment
std
::
string
env_class
=
config
[
"environment"
][
"environment_class"
].
as
<
std
::
string
>
();
context_ptr
->
environment
.
reset
(
CREATE_CLASS
(
RuntimeEnvironment
,
env_class
));
if
(
context_ptr
->
environment
->
initialize
(
config
[
"environment"
])
!=
0
)
{
return
-
1
;
}
//file_system
context_ptr
->
file_system
.
reset
(
CREATE_
CLASS
(
FileSystem
,
"AutoFileSystem"
));
context_ptr
->
file_system
.
reset
(
CREATE_
INSTANCE
(
FileSystem
,
"AutoFileSystem"
));
if
(
context_ptr
->
file_system
->
initialize
(
config
[
"io"
],
context_ptr
)
!=
0
)
{
return
-
1
;
}
//epoch
std
::
string
epoch_class
=
config
[
"epoch"
][
"epoch_class"
].
as
<
std
::
string
>
();
context_ptr
->
epoch_accessor
.
reset
(
CREATE_
CLASS
(
EpochAccessor
,
epoch_class
));
context_ptr
->
epoch_accessor
.
reset
(
CREATE_
INSTANCE
(
EpochAccessor
,
epoch_class
));
if
(
context_ptr
->
epoch_accessor
->
initialize
(
config
[
"epoch"
],
context_ptr
)
!=
0
)
{
return
-
1
;
}
...
...
@@ -55,10 +49,12 @@ int InitEnvProcess::run() {
VLOG
(
3
)
<<
"Trainer Resume From epoch:"
<<
epoch_accessor
->
current_epoch_id
();
auto
next_epoch_id
=
epoch_accessor
->
next_epoch_id
(
epoch_accessor
->
current_epoch_id
());
_context_ptr
->
dataset
->
pre_detect_data
(
next_epoch_id
);
//step 1. psserver init
//step2. psserver load
VLOG
(
3
)
<<
"Psserver Start Success"
;
if
(
epoch_accessor
->
checkpoint_path
().
size
()
>
0
)
{
//Load Model
}
else
{
//Random Init Model
}
//context_ptr->pslib_client()->load_model();
VLOG
(
3
)
<<
"Psserver Load Model Success"
;
return
0
;
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
f6bd8cfd
...
...
@@ -25,7 +25,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
_threads_executor
[
i
].
resize
(
_executor_num
);
for
(
int
e
=
0
;
e
<
_executor_num
;
++
e
)
{
auto
e_class
=
config
[
"executor"
][
e
][
"class"
].
as
<
std
::
string
>
();
auto
*
e_ptr
=
CREATE_
CLASS
(
Executor
,
e_class
);
auto
*
e_ptr
=
CREATE_
INSTANCE
(
Executor
,
e_class
);
_threads_executor
[
i
][
e
].
reset
(
e_ptr
);
if
(
e_ptr
->
initialize
(
config
[
"executor"
][
e
],
context_ptr
)
!=
0
)
{
ret
=
-
1
;
...
...
@@ -84,53 +84,59 @@ int LearnerProcess::run() {
while
(
true
)
{
epoch_accessor
->
next_epoch
();
bool
already_dump_inference_model
=
false
;
epoch_id
=
epoch_accessor
->
current_epoch_id
();
std
::
string
epoch_log_title
=
paddle
::
string
::
format_string
(
"train epoch_id:%d label:%s"
,
epoch_id
,
epoch_accessor
->
text
(
epoch_id
).
c_str
());
//Step1. 等待样本ready
environment
->
log
(
EnvironmentRole
::
WORKER
,
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Start %s, wait data ready"
,
epoch_log_title
.
c_str
());
while
(
dataset
->
epoch_data_status
(
epoch_id
)
!=
DatasetStatus
::
Ready
)
{
sleep
(
30
);
dataset
->
pre_detect_data
(
epoch_id
);
{
environment
->
log
(
EnvironmentRole
::
WORKER
,
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data not ready, wait 30s"
,
epoch_log_title
.
c_str
());
}
environment
->
log
(
EnvironmentRole
::
WORKER
,
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data is ready, start traning"
,
epoch_log_title
.
c_str
());
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
"Start %s, wait data ready"
,
epoch_log_title
.
c_str
());
while
(
dataset
->
epoch_data_status
(
epoch_id
)
!=
DatasetStatus
::
Ready
)
{
sleep
(
30
);
dataset
->
pre_detect_data
(
epoch_id
);
environment
->
log
(
EnvironmentRole
::
WORKER
,
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data not ready, wait 30s"
,
epoch_log_title
.
c_str
());
}
environment
->
log
(
EnvironmentRole
::
WORKER
,
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"%s, data is ready, start traning"
,
epoch_log_title
.
c_str
());
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
}
//Step2. 运行训练网络
bool
already_dump_inference_model
=
false
;
for
(
int
i
=
0
;
i
<
_executor_num
;
++
i
)
{
std
::
vector
<
std
::
shared_ptr
<
std
::
thread
>>
train_threads
(
_train_thread_num
);
for
(
int
thread_id
=
0
;
thread_id
<
_train_thread_num
;
++
thread_id
)
{
train_threads
[
i
].
reset
(
new
std
::
thread
([
this
](
int
exe_idx
,
int
thread_idx
)
{
auto
*
executor
=
_threads_executor
[
thread_idx
][
exe_idx
].
get
();
run_executor
(
executor
);
},
i
,
thread_id
));
}
for
(
int
i
=
0
;
i
<
_train_thread_num
;
++
i
)
{
train_threads
[
i
]
->
join
();
{
for
(
int
i
=
0
;
i
<
_executor_num
;
++
i
)
{
std
::
vector
<
std
::
shared_ptr
<
std
::
thread
>>
train_threads
(
_train_thread_num
);
for
(
int
thread_id
=
0
;
thread_id
<
_train_thread_num
;
++
thread_id
)
{
train_threads
[
i
].
reset
(
new
std
::
thread
([
this
](
int
exe_idx
,
int
thread_idx
)
{
auto
*
executor
=
_threads_executor
[
thread_idx
][
exe_idx
].
get
();
run_executor
(
executor
);
},
i
,
thread_id
));
}
for
(
int
i
=
0
;
i
<
_train_thread_num
;
++
i
)
{
train_threads
[
i
]
->
join
();
}
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
if
(
_threads_executor
[
0
][
i
]
->
is_dump_all_model
())
{
already_dump_inference_model
=
true
;
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
}
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
}
if
(
_threads_executor
[
0
][
i
]
->
is_dump_all_model
())
{
//Step3. Dump Model For Delta&&Checkpoint
{
if
(
!
already_dump_inference_model
)
{
already_dump_inference_model
=
true
;
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
}
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpoint
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
}
//Step3. Dump Model For Delta&&Checkpoint
if
(
!
already_dump_inference_model
)
{
already_dump_inference_model
=
true
;
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceDelta
);
}
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpoint
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
//Step4. Output Monitor && RunStatus
//TODO
}
...
...
paddle/fluid/train/custom_trainer/feed/process/process.cc
浏览文件 @
f6bd8cfd
...
...
@@ -5,8 +5,8 @@
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
REGIST
ER
_CLASS
(
Process
,
InitEnvProcess
);
REGIST
ER
_CLASS
(
Process
,
LearnerProcess
);
REGIST_CLASS
(
Process
,
InitEnvProcess
);
REGIST_CLASS
(
Process
,
LearnerProcess
);
int
Process
::
run
()
{
return
0
;
}
...
...
paddle/fluid/train/custom_trainer/feed/process/process.h
浏览文件 @
f6bd8cfd
...
...
@@ -18,7 +18,7 @@ public:
protected:
TrainerContext
*
_context_ptr
=
NULL
;
};
REGIST
ER
_REGISTERER
(
Process
);
REGIST_REGISTERER
(
Process
);
}
// namespace feed
}
// namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
f6bd8cfd
...
...
@@ -4,6 +4,7 @@
#include <vector>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
...
...
@@ -38,6 +39,7 @@ public:
YAML
::
Node
trainer_config
;
paddle
::
platform
::
CPUPlace
cpu_place
;
std
::
shared_ptr
<
PSlib
>
pslib
;
std
::
shared_ptr
<
Dataset
>
dataset
;
//训练样本
std
::
shared_ptr
<
FileSystem
>
file_system
;
//文件操作辅助类
std
::
vector
<
TableMeta
>
params_table_list
;
//参数表
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc
浏览文件 @
f6bd8cfd
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
@@ -36,7 +37,7 @@ const char test_data_dir[] = "test_data";
class
DataReaderTest
:
public
testing
::
Test
{
public:
static
void
SetUpTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
fs
->
mkdir
(
test_data_dir
);
shell_set_verbose
(
true
);
...
...
@@ -56,14 +57,14 @@ public:
}
static
void
TearDownTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
fs
->
remove
(
test_data_dir
);
}
virtual
void
SetUp
()
{
thread_num
=
omp_get_max_threads
();
omp_set_num_threads
(
1
);
fs
.
reset
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
fs
.
reset
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
context_ptr
.
reset
(
new
TrainerContext
());
}
...
...
@@ -79,7 +80,7 @@ public:
};
TEST_F
(
DataReaderTest
,
LineDataParser
)
{
std
::
unique_ptr
<
DataParser
>
data_parser
(
CREATE_
CLASS
(
DataParser
,
"LineDataParser"
));
std
::
unique_ptr
<
DataParser
>
data_parser
(
CREATE_
INSTANCE
(
DataParser
,
"LineDataParser"
));
ASSERT_NE
(
nullptr
,
data_parser
);
auto
config
=
YAML
::
Load
(
""
);
...
...
@@ -108,7 +109,7 @@ TEST_F(DataReaderTest, LineDataParser) {
}
TEST_F
(
DataReaderTest
,
LineDataReader
)
{
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
CLASS
(
DataReader
,
"LineDataReader"
));
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
INSTANCE
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
...
...
@@ -161,7 +162,7 @@ TEST_F(DataReaderTest, LineDataReader) {
}
TEST_F
(
DataReaderTest
,
LineDataReader_filename_prefix
)
{
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
CLASS
(
DataReader
,
"LineDataReader"
));
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
INSTANCE
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
"parser:
\n
"
...
...
@@ -196,7 +197,7 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
}
TEST_F
(
DataReaderTest
,
LineDataReader_FileSystem
)
{
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
CLASS
(
DataReader
,
"LineDataReader"
));
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
INSTANCE
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
"parser:
\n
"
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc
浏览文件 @
f6bd8cfd
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
@@ -37,7 +38,7 @@ const char test_data_dir[] = "test_data";
class
DataReaderOmpTest
:
public
testing
::
Test
{
public:
static
void
SetUpTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
fs
->
mkdir
(
test_data_dir
);
shell_set_verbose
(
true
);
std_items
.
clear
();
...
...
@@ -61,14 +62,14 @@ public:
}
static
void
TearDownTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
fs
->
remove
(
test_data_dir
);
}
virtual
void
SetUp
()
{
thread_num
=
omp_get_max_threads
();
omp_set_num_threads
(
1
);
fs
.
reset
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
fs
.
reset
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
context_ptr
.
reset
(
new
TrainerContext
());
}
...
...
@@ -111,7 +112,7 @@ std::vector<DataItem> DataReaderOmpTest::std_items;
std
::
vector
<
DataItem
>
DataReaderOmpTest
::
sorted_std_items
;
TEST_F
(
DataReaderOmpTest
,
LineDataReaderSingleThread
)
{
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
CLASS
(
DataReader
,
"LineDataReader"
));
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
INSTANCE
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
...
...
@@ -148,7 +149,7 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
}
TEST_F
(
DataReaderOmpTest
,
LineDataReaderMuiltThread
)
{
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
CLASS
(
DataReader
,
"LineDataReader"
));
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_
INSTANCE
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
浏览文件 @
f6bd8cfd
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <fstream>
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
@@ -38,7 +39,7 @@ class SimpleExecutorTest : public testing::Test
public:
static
void
SetUpTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
fs
->
mkdir
(
test_data_dir
);
shell_set_verbose
(
true
);
...
...
@@ -70,7 +71,7 @@ public:
static
void
TearDownTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
CLASS
(
FileSystem
,
"LocalFileSystem"
));
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_
INSTANCE
(
FileSystem
,
"LocalFileSystem"
));
fs
->
remove
(
test_data_dir
);
}
...
...
@@ -88,7 +89,7 @@ public:
};
TEST_F
(
SimpleExecutorTest
,
initialize
)
{
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_
CLASS
(
Executor
,
"SimpleExecutor"
));
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_
INSTANCE
(
Executor
,
"SimpleExecutor"
));
ASSERT_NE
(
nullptr
,
executor
);
YAML
::
Node
config
=
YAML
::
Load
(
"[1, 2, 3]"
);
ASSERT_NE
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
...
...
@@ -99,7 +100,7 @@ TEST_F(SimpleExecutorTest, initialize) {
}
TEST_F
(
SimpleExecutorTest
,
run
)
{
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_
CLASS
(
Executor
,
"SimpleExecutor"
));
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_
INSTANCE
(
Executor
,
"SimpleExecutor"
));
ASSERT_NE
(
nullptr
,
executor
);
auto
config
=
YAML
::
Load
(
string
::
format_string
(
"{thread_num: 2, startup_program: %s, main_program: %s}"
,
startup_program_path
,
main_program_path
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录