Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
dd67ad08
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dd67ad08
编写于
3月 09, 2019
作者:
X
xjqbest
提交者:
dongdaxiang
3月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify c++ and python dataset related code & fix bug
上级
cc4def6b
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
74 addition
and
32 deletion
+74
-32
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+12
-0
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+5
-2
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+6
-3
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+1
-2
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+1
-1
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+2
-4
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+3
-1
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+1
-1
paddle/fluid/framework/trainer.h
paddle/fluid/framework/trainer.h
+3
-3
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+3
-0
python/paddle/fluid/data_feed_desc.py
python/paddle/fluid/data_feed_desc.py
+0
-4
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+24
-10
python/paddle/fluid/distributed/ps_instance.py
python/paddle/fluid/distributed/ps_instance.py
+12
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
dd67ad08
...
...
@@ -206,7 +206,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
variable_helper timer
)
variable_helper timer
fs shell
)
cc_test
(
data_feed_test SRCS data_feed_test.cc DEPS async_executor
)
...
...
paddle/fluid/framework/async_executor.cc
浏览文件 @
dd67ad08
...
...
@@ -59,6 +59,12 @@ void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
fleet_ptr_
->
GatherServers
(
host_sign_list
,
node_num
);
}
// todo InitModel
void
AsyncExecutor
::
InitModel
()
{
}
// todo SaveModel
void
AsyncExecutor
::
SaveModel
(
const
std
::
string
&
path
)
{
}
void
AsyncExecutor
::
RunFromFile
(
const
ProgramDesc
&
main_program
,
const
std
::
string
&
data_feed_desc_str
,
const
std
::
vector
<
std
::
string
>&
filelist
,
...
...
@@ -154,5 +160,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
return
;
}
// todo RunFromDataset
void
AsyncExecutor
::
RunFromDataset
(
const
ProgramDesc
&
main_program
,
Dataset
*
data_set
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
)
{
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/data_feed.cc
浏览文件 @
dd67ad08
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.h"
#include <stdio_ext.h>
#include <utility>
#include "gflags/gflags.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
...
...
@@ -135,6 +136,7 @@ int PrivateQueueDataFeed<T>::Next() {
return
batch_size_
;
}
// explicit instantiation
template
class
PrivateQueueDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
template
<
typename
T
>
...
...
@@ -220,8 +222,6 @@ void InMemoryDataFeed<T>::LocalShuffle() {
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
}
template
class
InMemoryDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
// todo global shuffle
/*
template <typename T>
...
...
@@ -242,6 +242,9 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
}
*/
// explicit instantiation
template
class
InMemoryDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
void
MultiSlotDataFeed
::
Init
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
{
finish_init_
=
false
;
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
dd67ad08
...
...
@@ -12,6 +12,9 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_feed_factory.h"
...
...
@@ -44,9 +47,9 @@ void Dataset::SetThreadNum(int thread_num) {
void
Dataset
::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
void
Dataset
::
SetDataFeedDesc
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
{
data_feed_desc_
=
data_feed_desc
;
void
Dataset
::
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
{
google
::
protobuf
::
TextFormat
::
ParseFromString
(
data_feed_desc_str
,
&
data_feed_desc_
)
;
}
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
...
...
paddle/fluid/framework/data_set.h
浏览文件 @
dd67ad08
...
...
@@ -34,8 +34,7 @@ class Dataset {
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
SetDataFeedDesc
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
);
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
);
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
{
return
filelist_
;
}
virtual
int
GetThreadNum
()
{
return
thread_num_
;
}
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
dd67ad08
...
...
@@ -22,7 +22,7 @@ namespace paddle {
namespace
framework
{
void
DistMultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
data_set
)
{
Dataset
*
data_set
)
{
thread_num_
=
trainer_desc
.
thread_num
();
workers_
.
resize
(
thread_num_
);
readers_
.
resize
(
thread_num_
);
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
dd67ad08
...
...
@@ -14,11 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
...
...
@@ -119,7 +117,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
void
Executor
::
RunFromDataset
(
const
ProgramDesc
&
main_program
,
const
Dataset
&
dataset
,
Dataset
*
dataset
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
)
{
VLOG
(
3
)
<<
"Start to RunFromDataset in executor"
;
...
...
paddle/fluid/framework/executor.h
浏览文件 @
dd67ad08
...
...
@@ -19,6 +19,8 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include <unordered_map>
#include <memory>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
...
...
@@ -112,7 +114,7 @@ class Executor {
void
EnableMKLDNN
(
const
ProgramDesc
&
program
);
void
RunFromDataset
(
const
ProgramDesc
&
main_program
,
const
Dataset
&
dataset
,
void
RunFromDataset
(
const
ProgramDesc
&
main_program
,
Dataset
*
dataset
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
);
public:
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
dd67ad08
...
...
@@ -22,7 +22,7 @@ namespace paddle {
namespace
framework
{
void
MultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
dataset
)
{
Dataset
*
dataset
)
{
thread_num_
=
trainer_desc
.
thread_num
();
// get filelist from trainer_desc here
workers_
.
resize
(
thread_num_
);
...
...
paddle/fluid/framework/trainer.h
浏览文件 @
dd67ad08
...
...
@@ -42,7 +42,7 @@ class TrainerBase {
void
SetScope
(
Scope
*
root_scope
);
void
SetDebug
(
const
bool
debug
)
{
debug_
=
debug
;
}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
data_set
)
=
0
;
Dataset
*
data_set
)
=
0
;
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
const
platform
::
Place
&
place
)
=
0
;
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
=
0
;
...
...
@@ -62,7 +62,7 @@ class MultiTrainer : public TrainerBase {
MultiTrainer
()
{}
virtual
~
MultiTrainer
()
{}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
data_set
);
Dataset
*
data_set
);
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
const
platform
::
Place
&
place
);
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{}
...
...
@@ -81,7 +81,7 @@ class DistMultiTrainer : public MultiTrainer {
DistMultiTrainer
()
{}
virtual
~
DistMultiTrainer
()
{}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
const
Dataset
&
data_set
);
Dataset
*
data_set
);
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
);
virtual
void
Finalize
();
...
...
python/paddle/fluid/__init__.py
浏览文件 @
dd67ad08
...
...
@@ -24,6 +24,9 @@ from .executor import *
from
.
import
data_feed_desc
from
.data_feed_desc
import
*
from
.
import
dataset
from
.dataset
import
*
from
.
import
async_executor
from
.async_executor
import
*
...
...
python/paddle/fluid/data_feed_desc.py
浏览文件 @
dd67ad08
...
...
@@ -139,10 +139,6 @@ class DataFeedDesc(object):
self
.
proto_desc
.
multi_slot_desc
.
slots
[
self
.
__name_to_index
[
name
]].
is_used
=
True
def
global_shuffle
(
self
):
self
.
data
.
global_shuffle
()
pass
def
desc
(
self
):
"""
Returns a protobuf message for this DataFeedDesc
...
...
python/paddle/fluid/dataset.py
浏览文件 @
dd67ad08
...
...
@@ -23,9 +23,9 @@ class DatasetFactory(object):
pass
def
create_dataset
(
self
,
datafeed_class
):
datafeed_class
=
datafeed_class
.
capitalize
()
try
:
dataset
=
globals
()[
datafeed_class
]()
return
dataset
except
:
raise
ValueError
(
"datafeed class %s does not exist"
%
datafeed_class
)
...
...
@@ -37,6 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance
self
.
proto_desc
=
data_feed_pb2
.
DataFeedDesc
()
self
.
proto_desc
.
pipe_command
=
"cat"
self
.
dataset
=
core
.
Dataset
()
def
set_pipe_command
(
self
,
pipe_command
):
"""
...
...
@@ -60,17 +61,23 @@ class DatasetBase(object):
"""
self
.
proto_desc
.
batch_size
=
batch_size
def
set_thread
(
self
,
thread_num
):
self
.
dataset
.
set_thread_num
(
thread_num
)
def
set_filelist
(
self
,
filelist
):
self
.
dataset
.
set_filelist
(
filelist
)
def
set_use_var
(
self
,
var_list
):
multi_slot
=
self
.
proto_desc
.
multi_slot_desc
()
multi_slot
=
self
.
proto_desc
.
multi_slot_desc
for
var
in
var_list
:
slot_var
=
multi_slot
.
add
()
slot_var
=
multi_slot
.
slots
.
add
()
slot_var
.
is_used
=
True
slot_var
.
name
=
var
.
name
if
var
.
lod_level
==
0
:
slot_var
.
is_dense
=
True
if
var
.
dtype
==
core
.
VarType
.
FP32
:
if
var
.
dtype
==
core
.
Var
Desc
.
Var
Type
.
FP32
:
slot_var
.
type
=
"float32"
elif
var
.
dtype
==
core
.
VarType
.
INT64
:
elif
var
.
dtype
==
core
.
Var
Desc
.
Var
Type
.
INT64
:
slot_var
.
type
=
"uint64"
else
:
raise
ValueError
(
...
...
@@ -93,17 +100,24 @@ class DatasetBase(object):
class
InMemoryDataset
(
DatasetBase
):
def
__init__
(
self
):
super
(
InMemoryDataset
.
__init__
())
self
.
proto_desc
.
name
=
"InMemoryDataFeed"
super
(
InMemoryDataset
,
self
).
__init__
()
self
.
proto_desc
.
name
=
"MultiSlotInMemoryDataFeed"
def
load_into_memory
(
self
):
self
.
dataset
.
set_data_feed_desc
(
self
.
desc
())
self
.
dataset
.
load_into_memory
()
def
local_shuffle
(
self
):
pass
self
.
dataset
.
local_shuffle
()
def
global_shuffle
(
self
):
pass
from
.distributed
import
ps_instance
instance
=
ps_instance
.
PaddlePSInstance
(
1
,
2
)
self
.
dataset
.
set_trainer_num
(
instance
.
get_worker_num
())
self
.
global_shuffle
()
class
QueueDataset
(
DatasetBase
):
def
__init__
(
self
):
super
(
QueueDataset
.
__init__
()
)
super
(
QueueDataset
,
self
).
__init__
(
)
self
.
proto_desc
.
name
=
"MultiSlotDataFeed"
python/paddle/fluid/distributed/ps_instance.py
浏览文件 @
dd67ad08
...
...
@@ -121,6 +121,18 @@ class PaddlePSInstance(object):
"""
return
self
.
_nodes
def
get_worker_num
(
self
):
"""
Return worker num
"""
return
self
.
_worker_num
def
get_server_num
(
self
):
"""
Return server num
"""
return
self
.
_server_num
def
barrier_all
(
self
):
"""
barrier workers and servers
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录