Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b415ec27
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b415ec27
编写于
3月 09, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make Dataset* as an argument
上级
dd67ad08
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
134 addition
and
49 deletion
+134
-49
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+1
-1
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+1
-1
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+5
-12
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+4
-7
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+2
-7
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+5
-20
python/paddle/fluid/distributed/fleet.py
python/paddle/fluid/distributed/fleet.py
+63
-0
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+20
-0
python/paddle/fluid/trainer_factory.py
python/paddle/fluid/trainer_factory.py
+32
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
b415ec27
...
@@ -30,7 +30,7 @@ add_subdirectory(io)
...
@@ -30,7 +30,7 @@ add_subdirectory(io)
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
data_feed_proto SRCS data_feed.proto
)
proto_library
(
data_feed_proto SRCS data_feed.proto
)
proto_library
(
async_executor_proto SRCS data_feed.proto
)
proto_library
(
async_executor_proto SRCS data_feed.proto
)
proto_library
(
trainer_desc_proto SRCS trainer_desc.proto
)
proto_library
(
trainer_desc_proto SRCS trainer_desc.proto
data_feed.proto
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost enforce
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost enforce
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
b415ec27
...
@@ -52,7 +52,7 @@ void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
...
@@ -52,7 +52,7 @@ void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
data_feed_desc_str
,
&
data_feed_desc_
);
data_feed_desc_str
,
&
data_feed_desc_
);
}
}
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
Dataset
::
GetReaders
()
{
Dataset
::
GetReaders
()
{
return
readers_
;
return
readers_
;
}
}
...
...
paddle/fluid/framework/data_set.h
浏览文件 @
b415ec27
...
@@ -43,7 +43,7 @@ class Dataset {
...
@@ -43,7 +43,7 @@ class Dataset {
return
data_feed_desc_
;
return
data_feed_desc_
;
}
}
virtual
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
virtual
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
GetReaders
();
GetReaders
();
virtual
void
LoadIntoMemory
();
virtual
void
LoadIntoMemory
();
virtual
void
LocalShuffle
();
virtual
void
LocalShuffle
();
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
b415ec27
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer.h"
...
@@ -25,26 +26,18 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
...
@@ -25,26 +26,18 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset
*
data_set
)
{
Dataset
*
data_set
)
{
thread_num_
=
trainer_desc
.
thread_num
();
thread_num_
=
trainer_desc
.
thread_num
();
workers_
.
resize
(
thread_num_
);
workers_
.
resize
(
thread_num_
);
readers_
.
resize
(
thread_num_
);
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
data_set
->
GetReaders
();
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
trainer_desc
.
device_worker_name
());
readers_
[
i
]
=
DataFeedFactory
::
CreateDataFeed
(
trainer_desc
.
data_desc
().
name
());
workers_
[
i
]
->
SetDeviceIndex
(
i
);
workers_
[
i
]
->
SetDeviceIndex
(
i
);
readers_
[
i
]
->
Init
(
trainer_desc
.
data_desc
());
workers_
[
i
]
->
SetDataFeed
(
readers
[
i
]);
workers_
[
i
]
->
SetDataFeed
(
readers_
[
i
]);
workers_
[
i
]
->
Initialize
(
trainer_desc
);
workers_
[
i
]
->
Initialize
(
trainer_desc
);
}
}
std
::
vector
<
std
::
string
>
filelist_vec
;
for
(
unsigned
i
=
0
;
i
<
trainer_desc
.
filelist_size
();
++
i
)
{
filelist_vec
.
push_back
(
trainer_desc
.
filelist
(
i
));
}
readers_
[
0
]
->
SetFileList
(
filelist_vec
);
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
pull_dense_worker_
=
PullDenseWorker
::
GetInstance
();
pull_dense_worker_
=
PullDenseWorker
::
GetInstance
();
pull_dense_worker_
->
Initialize
(
trainer_desc
);
pull_dense_worker_
->
Initialize
(
trainer_desc
);
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
b415ec27
...
@@ -116,10 +116,9 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
...
@@ -116,10 +116,9 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
}
}
}
void
Executor
::
RunFromDataset
(
const
ProgramDesc
&
main_program
,
void
Executor
::
RunFromDataset
(
const
ProgramDesc
&
main_program
,
Scope
*
scope
,
Dataset
*
dataset
,
Dataset
*
dataset
,
const
std
::
string
&
trainer_desc_str
,
const
std
::
string
&
trainer_desc_str
)
{
const
bool
debug
)
{
VLOG
(
3
)
<<
"Start to RunFromDataset in executor"
;
VLOG
(
3
)
<<
"Start to RunFromDataset in executor"
;
TrainerDesc
trainer_desc
;
TrainerDesc
trainer_desc
;
google
::
protobuf
::
TextFormat
::
ParseFromString
(
trainer_desc_str
,
google
::
protobuf
::
TextFormat
::
ParseFromString
(
trainer_desc_str
,
...
@@ -132,9 +131,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program,
...
@@ -132,9 +131,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program,
VLOG
(
3
)
<<
"Going to initialize trainer"
;
VLOG
(
3
)
<<
"Going to initialize trainer"
;
trainer
->
Initialize
(
trainer_desc
,
dataset
);
trainer
->
Initialize
(
trainer_desc
,
dataset
);
VLOG
(
3
)
<<
"Set root scope here"
;
VLOG
(
3
)
<<
"Set root scope here"
;
trainer
->
SetScope
(
root_scope_
);
trainer
->
SetScope
(
scope
);
VLOG
(
3
)
<<
"Going to set debug"
;
trainer
->
SetDebug
(
debug
);
// prepare training environment and helper environment
// prepare training environment and helper environment
VLOG
(
3
)
<<
"Try to init train environment"
;
VLOG
(
3
)
<<
"Try to init train environment"
;
trainer
->
InitTrainerEnv
(
main_program
,
place_
);
trainer
->
InitTrainerEnv
(
main_program
,
place_
);
...
@@ -146,7 +143,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program,
...
@@ -146,7 +143,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program,
VLOG
(
3
)
<<
"Trainer going to finalize"
;
VLOG
(
3
)
<<
"Trainer going to finalize"
;
trainer
->
Finalize
();
trainer
->
Finalize
();
VLOG
(
3
)
<<
"Drop current scope kids"
;
VLOG
(
3
)
<<
"Drop current scope kids"
;
root_scope_
->
DropKids
();
scope
->
DropKids
();
return
;
return
;
}
}
...
...
paddle/fluid/framework/executor.h
浏览文件 @
b415ec27
...
@@ -114,16 +114,11 @@ class Executor {
...
@@ -114,16 +114,11 @@ class Executor {
void
EnableMKLDNN
(
const
ProgramDesc
&
program
);
void
EnableMKLDNN
(
const
ProgramDesc
&
program
);
void
RunFromDataset
(
const
ProgramDesc
&
main_program
,
Dataset
*
dataset
,
void
RunFromDataset
(
const
ProgramDesc
&
main_program
,
Scope
*
scope
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
);
Dataset
*
dataset
,
const
std
::
string
&
trainer_desc_str
);
public:
std
::
shared_ptr
<
paddle
::
framework
::
FleetWrapper
>
fleet_ptr_
;
Scope
*
root_scope_
;
private:
private:
const
platform
::
Place
place_
;
const
platform
::
Place
place_
;
int
actual_thread_num_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
b415ec27
...
@@ -26,31 +26,16 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
...
@@ -26,31 +26,16 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_
=
trainer_desc
.
thread_num
();
thread_num_
=
trainer_desc
.
thread_num
();
// get filelist from trainer_desc here
// get filelist from trainer_desc here
workers_
.
resize
(
thread_num_
);
workers_
.
resize
(
thread_num_
);
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
/*
dataset
->
GetReaders
();
if (NULL == dataset) {
readers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
readers_[i] =
DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
readers_[i]->Init(trainer_desc.data_desc());
}
std::vector<std::string> filelist_vec;
for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
filelist_vec.push_back(trainer_desc.filelist(i));
}
readers_[0]->SetFileList(filelist_vec);
} else {
// readers_ = dataset.get_readers(); ?
}
*/
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
trainer_desc
.
device_worker_name
());
workers_
[
i
]
->
SetDeviceIndex
(
i
);
workers_
[
i
]
->
SetDeviceIndex
(
i
);
workers_
[
i
]
->
SetDataFeed
(
readers
_
[
i
]);
workers_
[
i
]
->
SetDataFeed
(
readers
[
i
]);
}
}
// set debug here
}
}
// call only after all resources are set in current trainer
// call only after all resources are set in current trainer
...
...
python/paddle/fluid/distributed/fleet.py
0 → 100644
浏览文件 @
b415ec27
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from
..
import
core
__all__
=
[
'Fleet'
]
class
Fleet
(
object
):
"""
"""
def
__init__
(
self
):
self
.
instance_
=
ps_instance
.
PaddlePSInstance
()
self
.
fleet_
=
core
.
FleetWrapper
()
def
stop
(
self
):
self
.
instance_
.
barrier_worker
()
if
self
.
instance
.
is_first_worker
():
self
.
fleet_
.
stop_server
()
self
.
instance_
.
barrier_worker
()
self
.
instance_
.
barrier_all
()
self
.
instance
.
finalize
()
def
init_pserver
(
self
,
dist_desc
):
self
.
dist_desc_str_
=
text_format
.
MessageToString
(
dist_desc
)
self
.
dist_desc
=
dist_desc
self
.
fleet_
.
init_server
(
self
.
dist_desc_str_
)
ip
=
self
.
fleet_
.
start_server
()
self
.
instance_
.
set_ip
(
ip
)
self
.
instance
.
barrier_all
()
ips
=
self
.
instance
.
gather_ips
()
self
.
fleet
.
gather_servers
(
ips
,
self
.
instance_
.
get_node_cnt
())
self
.
instance_
.
barrier_all
()
def
init_worker
(
self
,
dist_desc
):
self
.
dist_desc_str_
=
text_format
.
MessageToString
(
dist_desc
)
self
.
dist_desc_
=
dist_desc
self
.
instance_
.
barrier_all
()
ips
=
self
.
instance
.
gather_ips
()
self
.
fleet_
.
init_worker
(
self
.
dist_desc_str_
,
ips
,
self
.
instance_
.
get_node_cnt
(),
self
.
instance
.
_rankid
)
self
.
instance
.
barrier_worker
()
def
init_pserver_model
(
self
):
if
self
.
instance_
.
is_first_worker
():
self
.
fleet_
.
init_model
()
self
.
instance_
.
barrier_worker
()
def
save_pserver_model
(
self
,
save_path
):
self
.
fleet_
.
save_model
(
save_path
)
python/paddle/fluid/executor.py
浏览文件 @
b415ec27
...
@@ -610,3 +610,23 @@ class Executor(object):
...
@@ -610,3 +610,23 @@ class Executor(object):
def
_run_inference
(
self
,
exe
,
feed
):
def
_run_inference
(
self
,
exe
,
feed
):
return
exe
.
run
(
feed
)
return
exe
.
run
(
feed
)
def
run_from_dataset
(
self
,
program
=
None
,
dataset
=
None
,
fetch_list
=
None
,
scope
=
None
,
opt_info
=
None
):
if
scope
is
None
:
scope
=
global_scope
()
if
fetch_list
is
None
:
fetch_list
=
[]
compiled
=
isinstance
(
program
,
compiler
.
CompiledProgram
)
if
not
compiled
:
trainer
=
TrainerFactory
().
create_trainer
(
opt_info
)
self
.
_default_executor
.
run_from_dataset
(
program_desc
,
trainer
.
_desc
())
else
:
# For compiled program, more runtime should be implemented
print
(
"run_from_dataset current does not support compiled program"
", we will support this later"
,
sys
.
stderr
)
python/paddle/fluid/trainer.py
→
python/paddle/fluid/trainer
_factory
.py
浏览文件 @
b415ec27
# Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 201
9
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,5 +12,21 @@
...
@@ -12,5 +12,21 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# NOTE: Trainer is moved into fluid.contrib.trainer.
__all__
=
[
"TrainerFactory"
]
__all__
=
[]
class
TrainerFactory
(
object
):
def
__init__
(
self
):
pass
def
create_trainer
(
self
,
opt_info
=
None
):
if
opt_info
==
None
:
return
MultiTrainer
()
else
:
if
opt_info
[
"optimizer"
]
==
"DownpourSGD"
:
trainer
=
DistMultiTrainer
()
trainer
.
gen_trainer_desc
(
fleet_desc
=
opt_info
[
"fleet"
],
worker
=
"downpour"
)
return
trainer
else
:
print
(
"Currently only support DownpourSGD"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录