Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
b66f0074
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b66f0074
编写于
3月 10, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix data reading bugs in api, add VLOG(3) log for setup
上级
71aa307e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
46 addition
and
9 deletion
+46
-9
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+4
-0
paddle/fluid/framework/data_feed_factory.cc
paddle/fluid/framework/data_feed_factory.cc
+3
-0
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+11
-3
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+3
-2
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+2
-1
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+1
-0
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+5
-0
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+7
-2
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+10
-1
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
b66f0074
...
@@ -44,10 +44,14 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
...
@@ -44,10 +44,14 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
bool
DataFeed
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
)
{
bool
DataFeed
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
CheckInit
();
CheckInit
();
// Do not set finish_set_filelist_ flag,
// since a user may set file many times after init reader
/*
if (finish_set_filelist_) {
if (finish_set_filelist_) {
VLOG(3) << "info: you have set the filelist.";
VLOG(3) << "info: you have set the filelist.";
return false;
return false;
}
}
*/
PADDLE_ENFORCE
(
files
.
size
(),
"You have set an empty filelist."
);
PADDLE_ENFORCE
(
files
.
size
(),
"You have set an empty filelist."
);
filelist_
.
assign
(
files
.
begin
(),
files
.
end
());
filelist_
.
assign
(
files
.
begin
(),
files
.
end
());
file_idx_
=
0
;
file_idx_
=
0
;
...
...
paddle/fluid/framework/data_feed_factory.cc
浏览文件 @
b66f0074
...
@@ -54,6 +54,9 @@ std::string DataFeedFactory::DataFeedTypeList() {
...
@@ -54,6 +54,9 @@ std::string DataFeedFactory::DataFeedTypeList() {
std
::
shared_ptr
<
DataFeed
>
DataFeedFactory
::
CreateDataFeed
(
std
::
shared_ptr
<
DataFeed
>
DataFeedFactory
::
CreateDataFeed
(
std
::
string
data_feed_class
)
{
std
::
string
data_feed_class
)
{
if
(
g_data_feed_map
.
count
(
data_feed_class
)
<
1
)
{
if
(
g_data_feed_map
.
count
(
data_feed_class
)
<
1
)
{
LOG
(
WARNING
)
<<
"Your DataFeed "
<<
data_feed_class
<<
"is not supported currently"
;
LOG
(
WARNING
)
<<
"Supported DataFeed: "
<<
DataFeedTypeList
();
exit
(
-
1
);
exit
(
-
1
);
}
}
return
g_data_feed_map
[
data_feed_class
]();
return
g_data_feed_map
[
data_feed_class
]();
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
b66f0074
...
@@ -12,10 +12,10 @@
...
@@ -12,10 +12,10 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License. */
* limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_feed_factory.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -24,6 +24,7 @@ namespace framework {
...
@@ -24,6 +24,7 @@ namespace framework {
Dataset
::
Dataset
()
{
thread_num_
=
1
;
}
Dataset
::
Dataset
()
{
thread_num_
=
1
;
}
void
Dataset
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
void
Dataset
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
VLOG
(
3
)
<<
"filelist size: "
<<
filelist
.
size
();
filelist_
=
filelist
;
filelist_
=
filelist
;
int
file_cnt
=
filelist_
.
size
();
int
file_cnt
=
filelist_
.
size
();
if
(
thread_num_
>
file_cnt
)
{
if
(
thread_num_
>
file_cnt
)
{
...
@@ -34,6 +35,8 @@ void Dataset::SetFileList(const std::vector<std::string>& filelist) {
...
@@ -34,6 +35,8 @@ void Dataset::SetFileList(const std::vector<std::string>& filelist) {
}
}
}
}
// buggy here, a user should set filelist first before this function
// not user friendly
void
Dataset
::
SetThreadNum
(
int
thread_num
)
{
void
Dataset
::
SetThreadNum
(
int
thread_num
)
{
int
file_cnt
=
filelist_
.
size
();
int
file_cnt
=
filelist_
.
size
();
if
(
file_cnt
!=
0
&&
thread_num
>
file_cnt
)
{
if
(
file_cnt
!=
0
&&
thread_num
>
file_cnt
)
{
...
@@ -48,8 +51,8 @@ void Dataset::SetThreadNum(int thread_num) {
...
@@ -48,8 +51,8 @@ void Dataset::SetThreadNum(int thread_num) {
void
Dataset
::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
void
Dataset
::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
void
Dataset
::
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
{
void
Dataset
::
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
{
google
::
protobuf
::
TextFormat
::
ParseFromString
(
google
::
protobuf
::
TextFormat
::
ParseFromString
(
data_feed_desc_str
,
data_feed_desc_str
,
&
data_feed_desc_
);
&
data_feed_desc_
);
}
}
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
...
@@ -107,14 +110,19 @@ void Dataset::GlobalShuffle() {
...
@@ -107,14 +110,19 @@ void Dataset::GlobalShuffle() {
}
}
void
Dataset
::
CreateReaders
()
{
void
Dataset
::
CreateReaders
()
{
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
VLOG
(
3
)
<<
"thread_num in Readers: "
<<
thread_num_
;
VLOG
(
3
)
<<
"readers size: "
<<
readers_
.
size
();
if
(
readers_
.
size
()
!=
0
)
{
if
(
readers_
.
size
()
!=
0
)
{
return
;
return
;
}
}
VLOG
(
3
)
<<
"data feed class name: "
<<
data_feed_desc_
.
name
();
for
(
int64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
readers_
.
push_back
(
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc_
.
name
()));
readers_
.
push_back
(
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc_
.
name
()));
readers_
.
back
()
->
Init
(
data_feed_desc_
);
readers_
.
back
()
->
Init
(
data_feed_desc_
);
}
}
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
readers_
[
0
]
->
SetFileList
(
filelist_
);
readers_
[
0
]
->
SetFileList
(
filelist_
);
}
}
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
b66f0074
...
@@ -23,12 +23,13 @@ namespace paddle {
...
@@ -23,12 +23,13 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
void
DistMultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
void
DistMultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
data
_
set
)
{
Dataset
*
dataset
)
{
thread_num_
=
trainer_desc
.
thread_num
();
thread_num_
=
trainer_desc
.
thread_num
();
workers_
.
resize
(
thread_num_
);
workers_
.
resize
(
thread_num_
);
dataset
->
CreateReaders
();
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
data
_
set
->
GetReaders
();
dataset
->
GetReaders
();
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
b66f0074
...
@@ -14,8 +14,9 @@ limitations under the License. */
...
@@ -14,8 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include <deque>
#include <deque>
#include <
unordered_set
>
#include <
memory
>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message.h"
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
b66f0074
...
@@ -90,6 +90,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
...
@@ -90,6 +90,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
int
batch_cnt
=
0
;
int
batch_cnt
=
0
;
timeline
.
Start
();
timeline
.
Start
();
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
LOG
(
WARNING
)
<<
"read a batch in thread "
<<
thread_id_
;
timeline
.
Pause
();
timeline
.
Pause
();
read_time
+=
timeline
.
ElapsedSec
();
read_time
+=
timeline
.
ElapsedSec
();
total_time
+=
timeline
.
ElapsedSec
();
total_time
+=
timeline
.
ElapsedSec
();
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
b66f0074
...
@@ -26,8 +26,12 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
...
@@ -26,8 +26,12 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_
=
trainer_desc
.
thread_num
();
thread_num_
=
trainer_desc
.
thread_num
();
// get filelist from trainer_desc here
// get filelist from trainer_desc here
workers_
.
resize
(
thread_num_
);
workers_
.
resize
(
thread_num_
);
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
dataset
->
CreateReaders
();
VLOG
(
3
)
<<
"readers created"
;
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
dataset
->
GetReaders
();
dataset
->
GetReaders
();
VLOG
(
3
)
<<
"readers num: "
<<
readers
.
size
();
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
trainer_desc
.
device_worker_name
());
...
@@ -50,6 +54,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
...
@@ -50,6 +54,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
}
}
void
MultiTrainer
::
Run
()
{
void
MultiTrainer
::
Run
()
{
VLOG
(
3
)
<<
"Going to run"
;
for
(
int
thidx
=
0
;
thidx
<
thread_num_
;
++
thidx
)
{
for
(
int
thidx
=
0
;
thidx
<
thread_num_
;
++
thidx
)
{
threads_
.
push_back
(
threads_
.
push_back
(
std
::
thread
(
&
DeviceWorker
::
TrainFiles
,
workers_
[
thidx
].
get
()));
std
::
thread
(
&
DeviceWorker
::
TrainFiles
,
workers_
[
thidx
].
get
()));
...
...
python/paddle/fluid/dataset.py
浏览文件 @
b66f0074
...
@@ -22,7 +22,7 @@ class DatasetFactory(object):
...
@@ -22,7 +22,7 @@ class DatasetFactory(object):
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
create_dataset
(
self
,
datafeed_class
):
def
create_dataset
(
self
,
datafeed_class
=
"QueueDataset"
):
try
:
try
:
dataset
=
globals
()[
datafeed_class
]()
dataset
=
globals
()[
datafeed_class
]()
return
dataset
return
dataset
...
@@ -38,6 +38,7 @@ class DatasetBase(object):
...
@@ -38,6 +38,7 @@ class DatasetBase(object):
self
.
proto_desc
=
data_feed_pb2
.
DataFeedDesc
()
self
.
proto_desc
=
data_feed_pb2
.
DataFeedDesc
()
self
.
proto_desc
.
pipe_command
=
"cat"
self
.
proto_desc
.
pipe_command
=
"cat"
self
.
dataset
=
core
.
Dataset
()
self
.
dataset
=
core
.
Dataset
()
self
.
thread_num
=
0
def
set_pipe_command
(
self
,
pipe_command
):
def
set_pipe_command
(
self
,
pipe_command
):
"""
"""
...
@@ -63,6 +64,7 @@ class DatasetBase(object):
...
@@ -63,6 +64,7 @@ class DatasetBase(object):
def
set_thread
(
self
,
thread_num
):
def
set_thread
(
self
,
thread_num
):
self
.
dataset
.
set_thread_num
(
thread_num
)
self
.
dataset
.
set_thread_num
(
thread_num
)
self
.
thread_num
=
thread_num
def
set_filelist
(
self
,
filelist
):
def
set_filelist
(
self
,
filelist
):
self
.
dataset
.
set_filelist
(
filelist
)
self
.
dataset
.
set_filelist
(
filelist
)
...
@@ -84,6 +86,9 @@ class DatasetBase(object):
...
@@ -84,6 +86,9 @@ class DatasetBase(object):
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
)
)
def
_prepare_to_run
(
self
):
self
.
dataset
.
set_data_feed_desc
(
self
.
desc
())
def
desc
(
self
):
def
desc
(
self
):
"""
"""
Returns a protobuf message for this DataFeedDesc
Returns a protobuf message for this DataFeedDesc
...
@@ -104,7 +109,7 @@ class InMemoryDataset(DatasetBase):
...
@@ -104,7 +109,7 @@ class InMemoryDataset(DatasetBase):
self
.
proto_desc
.
name
=
"MultiSlotInMemoryDataFeed"
self
.
proto_desc
.
name
=
"MultiSlotInMemoryDataFeed"
def
load_into_memory
(
self
):
def
load_into_memory
(
self
):
self
.
dataset
.
set_data_feed_desc
(
self
.
desc
()
)
_prepare_to_run
(
)
self
.
dataset
.
load_into_memory
()
self
.
dataset
.
load_into_memory
()
def
local_shuffle
(
self
):
def
local_shuffle
(
self
):
...
...
python/paddle/fluid/executor.py
浏览文件 @
b66f0074
...
@@ -23,6 +23,7 @@ from .framework import Program, default_main_program, Variable
...
@@ -23,6 +23,7 @@ from .framework import Program, default_main_program, Variable
from
.
import
core
from
.
import
core
from
.
import
compiler
from
.
import
compiler
from
..
import
compat
as
cpt
from
..
import
compat
as
cpt
from
.trainer_factory
import
TrainerFactory
__all__
=
[
'Executor'
,
'global_scope'
,
'scope_guard'
]
__all__
=
[
'Executor'
,
'global_scope'
,
'scope_guard'
]
...
@@ -616,6 +617,7 @@ class Executor(object):
...
@@ -616,6 +617,7 @@ class Executor(object):
dataset
=
None
,
dataset
=
None
,
fetch_list
=
None
,
fetch_list
=
None
,
scope
=
None
,
scope
=
None
,
thread
=
0
,
opt_info
=
None
):
opt_info
=
None
):
if
scope
is
None
:
if
scope
is
None
:
scope
=
global_scope
()
scope
=
global_scope
()
...
@@ -624,7 +626,14 @@ class Executor(object):
...
@@ -624,7 +626,14 @@ class Executor(object):
compiled
=
isinstance
(
program
,
compiler
.
CompiledProgram
)
compiled
=
isinstance
(
program
,
compiler
.
CompiledProgram
)
if
not
compiled
:
if
not
compiled
:
trainer
=
TrainerFactory
().
create_trainer
(
opt_info
)
trainer
=
TrainerFactory
().
create_trainer
(
opt_info
)
self
.
_default_executor
.
run_from_dataset
(
program_desc
,
if
thread
<=
0
:
trainer
.
set_thread
(
dataset
.
thread_num
)
else
:
trainer
.
set_thread
(
thread
)
dataset
.
_prepare_to_run
()
print
(
"run_from_dataset called"
)
self
.
_default_executor
.
run_from_dataset
(
program
.
desc
,
scope
,
dataset
.
dataset
,
trainer
.
_desc
())
trainer
.
_desc
())
else
:
else
:
# For compiled program, more runtime should be implemented
# For compiled program, more runtime should be implemented
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录