Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e36bbcc8
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e36bbcc8
编写于
3月 07, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix some typo and CMakefile.txt
上级
824b84d1
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
51 addition
and
60 deletion
+51
-60
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+18
-18
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+13
-15
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+4
-0
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+2
-6
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+1
-5
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+13
-16
未找到文件。
paddle/fluid/framework/data_set.cc
浏览文件 @
e36bbcc8
...
...
@@ -18,15 +18,14 @@
namespace
paddle
{
namespace
framework
{
Dataset
::
Dataset
()
{
thread_num_
=
1
;
}
Dataset
::
Dataset
()
{
thread_num_
=
1
;
}
void
Dataset
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
filelist_
=
filelist
;
int
file_cnt
=
filelist_
.
size
();
if
(
thread_num_
>
file_cnt
)
{
VLOG
(
1
)
<<
"DataSet thread num = "
<<
thread_num_
<<
", file num = "
<<
file_cnt
VLOG
(
1
)
<<
"DataSet thread num = "
<<
thread_num_
<<
", file num = "
<<
file_cnt
<<
". Changing DataSet thread num = "
<<
file_cnt
;
thread_num_
=
file_cnt
;
}
...
...
@@ -35,22 +34,23 @@ void Dataset::SetFileList(const std::vector<std::string>& filelist) {
void
Dataset
::
SetThreadNum
(
int
thread_num
)
{
int
file_cnt
=
filelist_
.
size
();
if
(
file_cnt
!=
0
&&
thread_num
>
file_cnt
)
{
VLOG
(
1
)
<<
"DataSet thread num = "
<<
thread_num
<<
", file num = "
<<
file_cnt
VLOG
(
1
)
<<
"DataSet thread num = "
<<
thread_num
<<
", file num = "
<<
file_cnt
<<
". Changing DataSet thread num = "
<<
file_cnt
;
thread_num
=
file_cnt
;
}
thread_num_
=
thread_num
;
}
void
Dataset
::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
void
Dataset
::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
void
Dataset
::
SetDataFeedDesc
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
{
void
Dataset
::
SetDataFeedDesc
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
{
data_feed_desc_
=
data_feed_desc
;
}
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
Dataset
::
GetReaders
()
{
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
Dataset
::
GetReaders
()
{
return
readers_
;
}
...
...
@@ -60,8 +60,8 @@ void Dataset::LoadIntoMemory() {
}
std
::
vector
<
std
::
thread
>
load_threads
;
for
(
int64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
load_threads
.
push_back
(
std
::
thread
(
&
paddle
::
framework
::
DataFeed
::
LoadIntoMemory
,
readers_
[
i
].
get
()));
load_threads
.
push_back
(
std
::
thread
(
&
paddle
::
framework
::
DataFeed
::
LoadIntoMemory
,
readers_
[
i
].
get
()));
}
for
(
std
::
thread
&
t
:
load_threads
)
{
t
.
join
();
...
...
@@ -74,8 +74,8 @@ void Dataset::LocalShuffle() {
}
std
::
vector
<
std
::
thread
>
local_shuffle_threads
;
for
(
int64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
local_shuffle_threads
.
push_back
(
std
::
thread
(
&
paddle
::
framework
::
DataFeed
::
LocalShuffle
,
readers_
[
i
].
get
()));
local_shuffle_threads
.
push_back
(
std
::
thread
(
&
paddle
::
framework
::
DataFeed
::
LocalShuffle
,
readers_
[
i
].
get
()));
}
for
(
std
::
thread
&
t
:
local_shuffle_threads
)
{
t
.
join
();
...
...
@@ -115,14 +115,14 @@ void Dataset::CreateReaders() {
readers_
[
0
]
->
SetFileList
(
filelist_
);
}
int
Dataset
::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
{
int
Dataset
::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
{
// can also use hash
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
// todo
int64_t
index
=
0
;
readers_
[
index
]
->
PutInsToChannel
(
msg
);
return
0
;
}
}
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/data_set.h
浏览文件 @
e36bbcc8
...
...
@@ -34,29 +34,27 @@ class Dataset {
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
SetDataFeedDesc
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
);
virtual
void
SetDataFeedDesc
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
);
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
{
return
filelist_
;
}
virtual
int
GetThreadNum
()
{
return
thread_num_
;
}
virtual
int
GetTrainerNum
()
{
return
trainer_num_
;
}
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
{
return
filelist_
;
}
virtual
int
GetThreadNum
()
{
return
thread_num_
;
}
virtual
int
GetTrainerNum
()
{
return
trainer_num_
;
}
virtual
const
paddle
::
framework
::
DataFeedDesc
&
GetDataFeedDesc
()
{
return
data_feed_desc_
;
}
virtual
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
GetReaders
();
virtual
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
GetReaders
();
virtual
void
LoadIntoMemory
();
virtual
void
LocalShuffle
();
// todo global shuffle
virtual
void
GlobalShuffle
();
virtual
void
CreateReaders
();
protected:
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
);
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
);
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers_
;
int
thread_num_
;
std
::
string
fs_name_
;
...
...
@@ -66,5 +64,5 @@ class Dataset {
int
trainer_num_
;
};
}
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/executor.cc
浏览文件 @
e36bbcc8
...
...
@@ -115,6 +115,10 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
}
void
Executor
::
RunFromDataset
(
const
ProgramDesc
&
pdesc
,
const
Dataset
&
dataset
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
)
{}
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
,
bool
create_vars
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
,
...
...
paddle/fluid/framework/executor.h
浏览文件 @
e36bbcc8
...
...
@@ -19,13 +19,13 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/data_set.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -112,11 +112,7 @@ class Executor {
void
EnableMKLDNN
(
const
ProgramDesc
&
program
);
void
RunFromTrainerDesc
(
const
ProgramDesc
&
main_program
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
);
void
RunFromDataset
(
const
ProgramDesc
&
main_program
,
Dataset
*
dataset
,
void
RunFromDataset
(
const
ProgramDesc
&
main_program
,
const
Dataset
&
dataset
,
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
);
public:
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
e36bbcc8
...
...
@@ -5,11 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wr
if
(
WITH_PYTHON
)
list
(
APPEND PYBIND_DEPS py_func_op
)
endif
()
<<<<<<< HEAD
set
(
PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc
)
=======
set
(
PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc fleet_wrapper_py.cc imperative.cc ir.cc inference_api.cc
)
>>>>>>> add pybind for fleet
set
(
PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc fleet_wrapper_py.cc data_set_py.cc imperative.cc ir.cc inference_api.cc
)
if
(
WITH_PYTHON
)
if
(
WITH_AMD_GPU
)
...
...
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
e36bbcc8
...
...
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
...
...
@@ -29,12 +27,12 @@ limitations under the License. */
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/pybind/data_set_py.h"
namespace
py
=
pybind11
;
namespace
pd
=
paddle
::
framework
;
...
...
@@ -43,10 +41,9 @@ namespace paddle {
namespace
pybind
{
void
BindDataset
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
Data
S
et
>
(
*
m
,
"Dataset"
)
py
::
class_
<
framework
::
Data
s
et
>
(
*
m
,
"Dataset"
)
.
def
(
py
::
init
([]()
{
return
std
::
unique_ptr
<
framework
::
Dataset
>
(
new
framework
::
Dataset
());
return
std
::
unique_ptr
<
framework
::
Dataset
>
(
new
framework
::
Dataset
());
}))
.
def
(
"set_filelist"
,
&
framework
::
Dataset
::
SetFileList
)
.
def
(
"set_thread_num"
,
&
framework
::
Dataset
::
SetThreadNum
)
...
...
@@ -54,7 +51,7 @@ void BindDataset(py::module* m) {
.
def
(
"set_data_feed_desc"
,
&
framework
::
Dataset
::
SetDataFeedDesc
)
.
def
(
"load_into_memory"
,
&
framework
::
Dataset
::
LoadIntoMemory
)
.
def
(
"local_shuffle"
,
&
framework
::
Dataset
::
LocalShuffle
)
.
def
(
"global_shuffle"
,
&
framework
::
Dataset
::
GLobalShuffle
)
.
def
(
"global_shuffle"
,
&
framework
::
Dataset
::
GlobalShuffle
);
}
}
// end namespace pybind
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录