Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
24863897
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
24863897
编写于
3月 08, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add RunFromDataset in executor
上级
e36bbcc8
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
94 addition
and
78 deletion
+94
-78
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+11
-28
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+1
-10
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+33
-30
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+2
-1
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+37
-4
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+3
-1
paddle/fluid/framework/trainer.h
paddle/fluid/framework/trainer.h
+7
-4
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
24863897
...
@@ -28,7 +28,7 @@ add_subdirectory(common)
...
@@ -28,7 +28,7 @@ add_subdirectory(common)
add_subdirectory
(
io
)
add_subdirectory
(
io
)
#ddim lib
#ddim lib
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
async_executor
_proto SRCS data_feed.proto
)
proto_library
(
data_feed
_proto SRCS data_feed.proto
)
proto_library
(
trainer_desc_proto SRCS trainer_desc.proto
)
proto_library
(
trainer_desc_proto SRCS trainer_desc.proto
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost enforce
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost enforce
)
...
@@ -175,15 +175,11 @@ cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc o
...
@@ -175,15 +175,11 @@ cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc o
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper
${
NGRAPH_EXE_DEPS
}
)
lod_rank_table feed_fetch_method sendrecvop_rpc
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper
${
NGRAPH_EXE_DEPS
}
trainer_library
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
else
()
if
(
WITH_NGRAPH
)
cc_library
(
executor SRCS executor.cc multi_trainer.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper
${
NGRAPH_EXE_DEPS
}
timer
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper
)
else
(
WITH_NGRAPH
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper
)
endif
(
WITH_NGRAPH
)
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
endif
()
endif
()
...
@@ -194,28 +190,15 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
...
@@ -194,28 +190,15 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper
)
fast_threaded_ssa_graph_executor variable_helper
)
if
(
WITH_PSLIB
)
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper pslib_brpc pslib timer
)
else
()
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
data_set.cc DEPS op_registry device_context scope framework_proto
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor
_proto
feed_fetch_method graph_to_program_pass data_feed
_proto
variable_helper timer
)
variable_helper timer
)
endif
(
WITH_PSLIB
)
cc_test
(
data_feed_test SRCS data_feed_test.cc DEPS async_executor
)
cc_test
(
data_feed_test SRCS data_feed_test.cc DEPS async_executor
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
...
...
paddle/fluid/framework/async_executor.cc
浏览文件 @
24863897
...
@@ -154,14 +154,5 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...
@@ -154,14 +154,5 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
return
;
return
;
}
}
// todo RunFromDataset
}
// end namespace framework
void
AsyncExecutor
::
RunFromDataset
(
const
ProgramDesc
&
main_program
,
Dataset
*
data_set
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
)
{
}
}
// einit_modelnd namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/data_feed.cc
浏览文件 @
24863897
...
@@ -135,9 +135,7 @@ int PrivateQueueDataFeed<T>::Next() {
...
@@ -135,9 +135,7 @@ int PrivateQueueDataFeed<T>::Next() {
return
batch_size_
;
return
batch_size_
;
}
}
#ifdef _WIN32
template
class
PrivateQueueDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
template
class
PrivateQueueDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
#endif
template
<
typename
T
>
template
<
typename
T
>
InMemoryDataFeed
<
T
>::
InMemoryDataFeed
()
{
InMemoryDataFeed
<
T
>::
InMemoryDataFeed
()
{
...
@@ -150,7 +148,7 @@ template <typename T>
...
@@ -150,7 +148,7 @@ template <typename T>
bool
InMemoryDataFeed
<
T
>::
Start
()
{
bool
InMemoryDataFeed
<
T
>::
Start
()
{
DataFeed
::
CheckSetFileList
();
DataFeed
::
CheckSetFileList
();
if
(
memory_data_
.
size
()
!=
0
)
{
if
(
memory_data_
.
size
()
!=
0
)
{
CHECK
(
cur_channel_
==
0
);
CHECK
_EQ
(
cur_channel_
,
0
);
shuffled_ins_
->
Extend
(
std
::
move
(
memory_data_
));
shuffled_ins_
->
Extend
(
std
::
move
(
memory_data_
));
std
::
vector
<
T
>
().
swap
(
memory_data_
);
std
::
vector
<
T
>
().
swap
(
memory_data_
);
}
}
...
@@ -205,11 +203,11 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
...
@@ -205,11 +203,11 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
std
::
string
filename
;
std
::
string
filename
;
while
(
DataFeed
::
PickOneFile
(
&
filename
))
{
while
(
DataFeed
::
PickOneFile
(
&
filename
))
{
int
err_no
=
0
;
int
err_no
=
0
;
PrivateQueueDataFeed
<
T
>::
fp_
=
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
fp_
=
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
__fsetlocking
(
&*
PrivateQueueDataFeed
<
T
>::
fp_
,
FSETLOCKING_BYCALLER
);
__fsetlocking
(
&*
PrivateQueueDataFeed
<
T
>::
fp_
,
FSETLOCKING_BYCALLER
);
T
instance
;
T
instance
;
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
local_vec
.
push_back
(
instance
);
local_vec
.
push_back
(
instance
);
}
}
memory_data_
.
insert
(
memory_data_
.
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
memory_data_
.
insert
(
memory_data_
.
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
...
@@ -242,6 +240,8 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
...
@@ -242,6 +240,8 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
}
}
*/
*/
template
class
InMemoryDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
void
MultiSlotDataFeed
::
Init
(
void
MultiSlotDataFeed
::
Init
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
{
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
{
finish_init_
=
false
;
finish_init_
=
false
;
...
@@ -633,7 +633,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
...
@@ -633,7 +633,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
}
}
}
}
bool
MultiSlotInMemoryDataFeed
::
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>*
instance
)
{
bool
MultiSlotInMemoryDataFeed
::
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>*
instance
)
{
std
::
string
line
;
std
::
string
line
;
if
(
getline
(
file_
,
line
))
{
if
(
getline
(
file_
,
line
))
{
int
use_slots_num
=
use_slots_
.
size
();
int
use_slots_num
=
use_slots_
.
size
();
...
@@ -725,12 +726,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
...
@@ -725,12 +726,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
}
}
// todo serialize ins in global shuffle
// todo serialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
const
std
::
vector
<
MultiSlotType
>&
ins
,
std
::
string
&
str
)
{
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
const
std
::
vector
<
MultiSlotType
>&
ins
,
std
::
string
&
str
)
{
return
;
}
}
// todo deserialize ins in global shuffle
// todo deserialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
MultiSlotType
>&
ins
,
const
std
::
string
&
str
)
{
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
MultiSlotType
>&
ins
,
const
std
::
string
&
str
)
{
return
;
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
24863897
...
@@ -21,7 +21,8 @@ limitations under the License. */
...
@@ -21,7 +21,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
void
DistMultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
data_set
)
{
void
DistMultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
data_set
)
{
thread_num_
=
trainer_desc
.
thread_num
();
thread_num_
=
trainer_desc
.
thread_num
();
workers_
.
resize
(
thread_num_
);
workers_
.
resize
(
thread_num_
);
readers_
.
resize
(
thread_num_
);
readers_
.
resize
(
thread_num_
);
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
24863897
...
@@ -19,13 +19,16 @@ limitations under the License. */
...
@@ -19,13 +19,16 @@ limitations under the License. */
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include <utility>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
...
@@ -115,9 +118,39 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
...
@@ -115,9 +118,39 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
}
}
}
void
Executor
::
RunFromDataset
(
const
ProgramDesc
&
pdesc
,
const
Dataset
&
dataset
,
void
Executor
::
RunFromDataset
(
const
ProgramDesc
&
main_program
,
const
Dataset
&
dataset
,
const
std
::
string
&
trainer_desc_str
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
)
{}
const
bool
debug
)
{
VLOG
(
3
)
<<
"Start to RunFromDataset in executor"
;
TrainerDesc
trainer_desc
;
google
::
protobuf
::
TextFormat
::
ParseFromString
(
trainer_desc_str
,
&
trainer_desc
);
VLOG
(
3
)
<<
"Going to create trainer, trainer class is "
<<
trainer_desc
.
class_name
();
std
::
shared_ptr
<
TrainerBase
>
trainer
;
trainer
=
TrainerFactory
::
CreateTrainer
(
trainer_desc
.
class_name
());
// initialize trainer
VLOG
(
3
)
<<
"Going to initialize trainer"
;
trainer
->
Initialize
(
trainer_desc
,
dataset
);
VLOG
(
3
)
<<
"Set root scope here"
;
trainer
->
SetScope
(
root_scope_
);
VLOG
(
3
)
<<
"Going to set debug"
;
trainer
->
SetDebug
(
debug
);
// prepare training environment and helper environment
VLOG
(
3
)
<<
"Try to init train environment"
;
trainer
->
InitTrainerEnv
(
main_program
,
place_
);
VLOG
(
3
)
<<
"Try to init other environment"
;
trainer
->
InitOtherEnv
(
main_program
);
// training and finalize training
VLOG
(
3
)
<<
"Trainer starts to run"
;
trainer
->
Run
();
VLOG
(
3
)
<<
"Trainer going to finalize"
;
trainer
->
Finalize
();
VLOG
(
3
)
<<
"Drop current scope kids"
;
root_scope_
->
DropKids
();
return
;
}
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
,
bool
create_vars
,
bool
create_local_scope
,
bool
create_vars
,
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
24863897
...
@@ -22,11 +22,12 @@ namespace paddle {
...
@@ -22,11 +22,12 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
void
MultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
void
MultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
dataset
)
{
const
Dataset
&
dataset
)
{
thread_num_
=
trainer_desc
.
thread_num
();
thread_num_
=
trainer_desc
.
thread_num
();
// get filelist from trainer_desc here
// get filelist from trainer_desc here
workers_
.
resize
(
thread_num_
);
workers_
.
resize
(
thread_num_
);
/*
if (NULL == dataset) {
if (NULL == dataset) {
readers_.resize(thread_num_);
readers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
for (int i = 0; i < thread_num_; ++i) {
...
@@ -42,6 +43,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
...
@@ -42,6 +43,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
} else {
} else {
// readers_ = dataset.get_readers(); ?
// readers_ = dataset.get_readers(); ?
}
}
*/
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
...
...
paddle/fluid/framework/trainer.h
浏览文件 @
24863897
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
...
@@ -29,7 +30,6 @@ limitations under the License. */
...
@@ -29,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/framework/data_set.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -41,7 +41,8 @@ class TrainerBase {
...
@@ -41,7 +41,8 @@ class TrainerBase {
// model memory are hosted in root_scope
// model memory are hosted in root_scope
void
SetScope
(
Scope
*
root_scope
);
void
SetScope
(
Scope
*
root_scope
);
void
SetDebug
(
const
bool
debug
)
{
debug_
=
debug
;
}
void
SetDebug
(
const
bool
debug
)
{
debug_
=
debug
;
}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
data_set
)
=
0
;
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
data_set
)
=
0
;
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
const
platform
::
Place
&
place
)
=
0
;
const
platform
::
Place
&
place
)
=
0
;
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
=
0
;
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
=
0
;
...
@@ -60,7 +61,8 @@ class MultiTrainer : public TrainerBase {
...
@@ -60,7 +61,8 @@ class MultiTrainer : public TrainerBase {
public:
public:
MultiTrainer
()
{}
MultiTrainer
()
{}
virtual
~
MultiTrainer
()
{}
virtual
~
MultiTrainer
()
{}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
data_set
);
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
data_set
);
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
const
platform
::
Place
&
place
);
const
platform
::
Place
&
place
);
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{}
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{}
...
@@ -78,7 +80,8 @@ class DistMultiTrainer : public MultiTrainer {
...
@@ -78,7 +80,8 @@ class DistMultiTrainer : public MultiTrainer {
public:
public:
DistMultiTrainer
()
{}
DistMultiTrainer
()
{}
virtual
~
DistMultiTrainer
()
{}
virtual
~
DistMultiTrainer
()
{}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
data_set
);
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
data_set
);
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
);
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
);
virtual
void
Finalize
();
virtual
void
Finalize
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录