Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
daba57f7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
daba57f7
编写于
12月 04, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
complete ctr_reader
上级
9f53aad1
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
143 addition
and
45 deletion
+143
-45
paddle/fluid/operators/reader/create_ctr_reader_op.cc
paddle/fluid/operators/reader/create_ctr_reader_op.cc
+3
-2
paddle/fluid/operators/reader/ctr_reader.cc
paddle/fluid/operators/reader/ctr_reader.cc
+3
-3
paddle/fluid/operators/reader/ctr_reader.h
paddle/fluid/operators/reader/ctr_reader.h
+34
-2
paddle/fluid/operators/reader/read_op.cc
paddle/fluid/operators/reader/read_op.cc
+25
-16
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+21
-13
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-0
python/paddle/fluid/contrib/__init__.py
python/paddle/fluid/contrib/__init__.py
+3
-0
python/paddle/fluid/contrib/reader/__init__.py
python/paddle/fluid/contrib/reader/__init__.py
+19
-0
python/paddle/fluid/contrib/reader/ctr_reader.py
python/paddle/fluid/contrib/reader/ctr_reader.py
+33
-9
python/setup.py.in
python/setup.py.in
+1
-0
未找到文件。
paddle/fluid/operators/reader/create_ctr_reader_op.cc
浏览文件 @
daba57f7
...
@@ -51,6 +51,7 @@ class CreateCTRReaderOp : public framework::OperatorBase {
...
@@ -51,6 +51,7 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto
file_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"file_list"
);
auto
file_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"file_list"
);
DataDesc
data_desc
(
batch_size
,
file_list
,
file_type
,
file_format
,
DataDesc
data_desc
(
batch_size
,
file_list
,
file_type
,
file_format
,
dense_slot_index
,
sparse_slot_index
,
sparse_slots
);
dense_slot_index
,
sparse_slot_index
,
sparse_slots
);
VLOG
(
1
)
<<
data_desc
;
out
->
Reset
(
std
::
make_shared
<
CTRReader
>
(
queue_holder
->
GetQueue
(),
thread_num
,
out
->
Reset
(
std
::
make_shared
<
CTRReader
>
(
queue_holder
->
GetQueue
(),
thread_num
,
data_desc
));
data_desc
));
}
}
...
@@ -69,10 +70,10 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
...
@@ -69,10 +70,10 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
"The list of files that need to read"
);
"The list of files that need to read"
);
AddAttr
<
std
::
vector
<
int
>>
(
AddAttr
<
std
::
vector
<
int
>>
(
"dense_slot_index"
,
"dense_slot_index"
,
"the
spar
se slots id that should be extract from file"
)
"the
den
se slots id that should be extract from file"
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
std
::
vector
<
int
>>
(
AddAttr
<
std
::
vector
<
int
>>
(
"
den
se_slot_index"
,
"
spar
se_slot_index"
,
"the sparse slots id that should be extract from file"
)
"the sparse slots id that should be extract from file"
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"sparse_slots"
,
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"sparse_slots"
,
...
...
paddle/fluid/operators/reader/ctr_reader.cc
浏览文件 @
daba57f7
...
@@ -157,8 +157,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
...
@@ -157,8 +157,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
}
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
}
VLOG
(
3
)
<<
"all reader thread is stopped,
push empty data into
queue"
;
VLOG
(
3
)
<<
"all reader thread is stopped,
close the
queue"
;
queue
->
Push
({}
);
queue
->
Close
(
);
VLOG
(
3
)
<<
"monitor thread exited"
;
VLOG
(
3
)
<<
"monitor thread exited"
;
}
}
...
@@ -247,7 +247,7 @@ static inline void parse_csv_line(
...
@@ -247,7 +247,7 @@ static inline void parse_csv_line(
int
slot_idx
=
data_desc
.
dense_slot_index_
[
i
];
int
slot_idx
=
data_desc
.
dense_slot_index_
[
i
];
auto
&
slot_data
=
ret
[
slot_idx
];
auto
&
slot_data
=
ret
[
slot_idx
];
std
::
vector
<
std
::
string
>
data_in_slot_str
;
std
::
vector
<
std
::
string
>
data_in_slot_str
;
string_split
(
ret
[
slot_idx
]
,
','
,
&
data_in_slot_str
);
string_split
(
slot_data
,
','
,
&
data_in_slot_str
);
std
::
vector
<
float
>
data_in_slot
;
std
::
vector
<
float
>
data_in_slot
;
for
(
auto
&
data_str
:
data_in_slot_str
)
{
for
(
auto
&
data_str
:
data_in_slot_str
)
{
(
*
dense_datas
)[
i
].
push_back
(
std
::
stof
(
data_str
));
(
*
dense_datas
)[
i
].
push_back
(
std
::
stof
(
data_str
));
...
...
paddle/fluid/operators/reader/ctr_reader.h
浏览文件 @
daba57f7
...
@@ -60,6 +60,35 @@ struct DataDesc {
...
@@ -60,6 +60,35 @@ struct DataDesc {
const
std
::
vector
<
std
::
string
>
sparse_slot_ids_
;
const
std
::
vector
<
std
::
string
>
sparse_slot_ids_
;
};
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
DataDesc
&
data_desc
)
{
os
<<
"data_desc:
\n
"
;
os
<<
"
\t
batch_size -> "
<<
data_desc
.
batch_size_
<<
"
\n
"
;
os
<<
"
\t
file_type -> "
<<
data_desc
.
file_type_
<<
"
\n
"
;
os
<<
"
\t
file_format -> "
<<
data_desc
.
file_format_
<<
"
\n
"
;
os
<<
"
\t
file_names -> {"
;
for
(
auto
&
file_name
:
data_desc
.
file_names_
)
{
os
<<
file_name
<<
","
;
}
os
<<
"}
\n
"
;
os
<<
"
\t
dense_slot_index -> {"
;
for
(
auto
&
slot
:
data_desc
.
dense_slot_index_
)
{
os
<<
slot
<<
","
;
}
os
<<
"}
\n
"
;
os
<<
"
\t
sparse_slot_index_ -> {"
;
for
(
auto
&
slot
:
data_desc
.
sparse_slot_index_
)
{
os
<<
slot
<<
","
;
}
os
<<
"}
\n
"
;
os
<<
"
\t
sparse_slot_ids_ -> {"
;
for
(
auto
&
slot
:
data_desc
.
sparse_slot_ids_
)
{
os
<<
slot
<<
","
;
}
os
<<
"}
\n
"
;
return
os
;
}
void
ReadThread
(
const
std
::
vector
<
std
::
string
>&
file_list
,
void
ReadThread
(
const
std
::
vector
<
std
::
string
>&
file_list
,
const
DataDesc
&
data_desc
,
int
thread_id
,
const
DataDesc
&
data_desc
,
int
thread_id
,
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
...
@@ -89,7 +118,7 @@ class CTRReader : public framework::FileReader {
...
@@ -89,7 +118,7 @@ class CTRReader : public framework::FileReader {
}
}
}
}
~
CTRReader
()
{}
~
CTRReader
()
{
Shutdown
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
bool
success
;
bool
success
;
...
@@ -106,7 +135,10 @@ class CTRReader : public framework::FileReader {
...
@@ -106,7 +135,10 @@ class CTRReader : public framework::FileReader {
for
(
auto
&
read_thread
:
read_threads_
)
{
for
(
auto
&
read_thread
:
read_threads_
)
{
read_thread
->
join
();
read_thread
->
join
();
}
}
if
(
monitor_thread_
)
{
monitor_thread_
->
join
();
monitor_thread_
->
join
();
}
read_threads_
.
clear
();
read_threads_
.
clear
();
monitor_thread_
.
reset
(
nullptr
);
monitor_thread_
.
reset
(
nullptr
);
...
...
paddle/fluid/operators/reader/read_op.cc
浏览文件 @
daba57f7
...
@@ -27,15 +27,16 @@ class ReadInferShape : public framework::InferShapeBase {
...
@@ -27,15 +27,16 @@ class ReadInferShape : public framework::InferShapeBase {
"The ReadOp must take a reader as input."
);
"The ReadOp must take a reader as input."
);
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
"Out"
),
"The ReadOp should be assigned with output."
);
"The ReadOp should be assigned with output."
);
if
(
!
ctx
->
IsRuntime
()
&&
ctx
->
Attrs
().
Get
<
bool
>
(
"infer_out"
))
{
std
::
vector
<
framework
::
DDim
>
reader_dims
=
ctx
->
GetReaderDims
(
"Reader"
);
std
::
vector
<
framework
::
DDim
>
reader_dims
=
ctx
->
GetReaderDims
(
"Reader"
);
std
::
vector
<
std
::
string
>
out_names
=
ctx
->
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
out_names
=
ctx
->
Outputs
(
"Out"
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
reader_dims
.
size
(),
out_names
.
size
(),
reader_dims
.
size
(),
out_names
.
size
(),
"The reader's dim number doesn't match the output number."
);
"The reader's dim number doesn't match the output number."
);
ctx
->
SetOutputsDim
(
"Out"
,
reader_dims
);
ctx
->
SetOutputsDim
(
"Out"
,
reader_dims
);
if
(
!
ctx
->
IsRuntime
())
{
auto
in_desc
=
auto
in_desc
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetInputVarPtrs
(
"Reader"
)[
0
]);
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetInputVarPtrs
(
"Reader"
)[
0
]);
std
::
cout
<<
in_desc
->
Proto
()
->
SerializeAsString
()
<<
std
::
endl
;
auto
in_lod_levels
=
in_desc
->
GetLoDLevels
();
auto
in_lod_levels
=
in_desc
->
GetLoDLevels
();
auto
out_var_ptrs
=
ctx
->
GetOutputVarPtrs
(
"Out"
);
auto
out_var_ptrs
=
ctx
->
GetOutputVarPtrs
(
"Out"
);
PADDLE_ENFORCE_EQ
(
in_lod_levels
.
size
(),
out_var_ptrs
.
size
(),
PADDLE_ENFORCE_EQ
(
in_lod_levels
.
size
(),
out_var_ptrs
.
size
(),
...
@@ -53,6 +54,8 @@ class ReadInferVarType : public framework::VarTypeInference {
...
@@ -53,6 +54,8 @@ class ReadInferVarType : public framework::VarTypeInference {
public:
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
bool
infer_out
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"infer_out"
));
if
(
infer_out
)
{
std
::
string
reader_name
=
op_desc
.
Input
(
"Reader"
)[
0
];
std
::
string
reader_name
=
op_desc
.
Input
(
"Reader"
)[
0
];
std
::
vector
<
std
::
string
>
out_names
=
op_desc
.
Output
(
"Out"
);
std
::
vector
<
std
::
string
>
out_names
=
op_desc
.
Output
(
"Out"
);
framework
::
VarDesc
*
reader
=
block
->
FindVarRecursive
(
reader_name
);
framework
::
VarDesc
*
reader
=
block
->
FindVarRecursive
(
reader_name
);
...
@@ -64,6 +67,7 @@ class ReadInferVarType : public framework::VarTypeInference {
...
@@ -64,6 +67,7 @@ class ReadInferVarType : public framework::VarTypeInference {
out
.
SetDataType
(
dtypes
[
i
]);
out
.
SetDataType
(
dtypes
[
i
]);
}
}
}
}
}
};
};
class
ReadOp
:
public
framework
::
OperatorBase
{
class
ReadOp
:
public
framework
::
OperatorBase
{
...
@@ -73,6 +77,7 @@ class ReadOp : public framework::OperatorBase {
...
@@ -73,6 +77,7 @@ class ReadOp : public framework::OperatorBase {
private:
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
VLOG
(
3
)
<<
"read op in"
;
framework
::
ReaderHolder
*
reader
=
framework
::
ReaderHolder
*
reader
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"Reader"
)),
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"Reader"
)),
"Cannot find reader variable %s"
,
Input
(
"Reader"
))
"Cannot find reader variable %s"
,
Input
(
"Reader"
))
...
@@ -87,7 +92,9 @@ class ReadOp : public framework::OperatorBase {
...
@@ -87,7 +92,9 @@ class ReadOp : public framework::OperatorBase {
reader
->
ReadNext
(
&
ins
);
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
if
(
ins
.
empty
())
{
VLOG
(
3
)
<<
"read empty data in"
;
if
(
Attr
<
bool
>
(
"throw_eof_exp"
))
{
if
(
Attr
<
bool
>
(
"throw_eof_exp"
))
{
VLOG
(
3
)
<<
"throw_eof_exp"
;
PADDLE_THROW_EOF
();
PADDLE_THROW_EOF
();
}
else
{
}
else
{
ins
.
resize
(
out_arg_names
.
size
());
ins
.
resize
(
out_arg_names
.
size
());
...
@@ -96,6 +103,7 @@ class ReadOp : public framework::OperatorBase {
...
@@ -96,6 +103,7 @@ class ReadOp : public framework::OperatorBase {
tensor
.
mutable_data
<
float
>
(
framework
::
make_ddim
({
0
}),
dev_place
);
tensor
.
mutable_data
<
float
>
(
framework
::
make_ddim
({
0
}),
dev_place
);
}
}
}
}
VLOG
(
3
)
<<
"read empty data out"
;
}
}
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
for
(
size_t
i
=
0
;
i
<
out_arg_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
out_arg_names
.
size
();
++
i
)
{
...
@@ -120,6 +128,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -120,6 +128,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" only when the data-balance is enabled in ParallelExecutor"
" only when the data-balance is enabled in ParallelExecutor"
" and it is set by ParallelExecutor instance, not users."
)
" and it is set by ParallelExecutor instance, not users."
)
.
SetDefault
(
true
);
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"infer_out"
,
""
).
SetDefault
(
true
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Read Operator
Read Operator
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
daba57f7
...
@@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() {
...
@@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() {
"It means the reader will generate two data each time,"
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."
);
"whose shapes are [2,3,4] and [5,6] respectively."
);
AddAttr
<
std
::
vector
<
int
>>
(
"lod_levels"
,
"The LoD levels of each data."
);
AddAttr
<
std
::
vector
<
int
>>
(
"lod_levels"
,
"The LoD levels of each data."
);
AddAttr
<
bool
>
(
"use_data_config"
,
"Use the config of all datas like shape_concat/ranks/lod_levels"
)
.
SetDefault
(
true
);
Apply
();
Apply
();
}
}
...
@@ -75,7 +79,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
...
@@ -75,7 +79,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output file reader should not be null."
);
"The output file reader should not be null."
);
const
auto
shape_concat
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape_concat"
);
bool
use_data_config
=
ctx
->
Attrs
().
Get
<
bool
>
(
"use_data_config"
);
if
(
use_data_config
)
{
const
auto
shape_concat
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
ranks
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ranks"
);
const
auto
ranks
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ranks"
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
ctx
->
SetReaderDims
(
"Out"
,
shapes
);
ctx
->
SetReaderDims
(
"Out"
,
shapes
);
...
@@ -88,6 +95,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
...
@@ -88,6 +95,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
framework
::
VarDesc
*
reader
=
framework
::
VarDesc
*
reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
reader
->
SetLoDLevels
(
lod_levels
);
reader
->
SetLoDLevels
(
lod_levels
);
}
}
}
void
FileReaderInferVarType
::
operator
()(
const
framework
::
OpDesc
&
op_desc
,
void
FileReaderInferVarType
::
operator
()(
const
framework
::
OpDesc
&
op_desc
,
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
daba57f7
...
@@ -364,6 +364,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -364,6 +364,7 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"start"
,
&
framework
::
ReaderHolder
::
Start
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
ResetAll
);
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
ResetAll
);
using
LoDTensorBlockingQueue
=
using
LoDTensorBlockingQueue
=
...
...
python/paddle/fluid/contrib/__init__.py
浏览文件 @
daba57f7
...
@@ -22,9 +22,12 @@ from . import op_frequence
...
@@ -22,9 +22,12 @@ from . import op_frequence
from
.op_frequence
import
*
from
.op_frequence
import
*
from
.
import
quantize
from
.
import
quantize
from
.quantize
import
*
from
.quantize
import
*
from
.
import
reader
from
.reader
import
*
__all__
=
[]
__all__
=
[]
__all__
+=
decoder
.
__all__
__all__
+=
decoder
.
__all__
__all__
+=
memory_usage_calc
.
__all__
__all__
+=
memory_usage_calc
.
__all__
__all__
+=
op_frequence
.
__all__
__all__
+=
op_frequence
.
__all__
__all__
+=
quantize
.
__all__
__all__
+=
quantize
.
__all__
__all__
+=
reader
.
__all__
python/paddle/fluid/contrib/reader/__init__.py
0 → 100644
浏览文件 @
daba57f7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
.
import
ctr_reader
__all__
=
ctr_reader
.
__all__
python/paddle/fluid/contrib/reader/ctr_reader.py
浏览文件 @
daba57f7
...
@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \
...
@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \
default_startup_program
,
Variable
default_startup_program
,
Variable
from
paddle.fluid.unique_name
import
generate
as
unique_name
from
paddle.fluid.unique_name
import
generate
as
unique_name
__all__
=
[
'ctr_reader'
]
def
monkey_patch_reader_methods
(
reader
):
def
monkey_patch_reader_methods
(
reader
):
def
__get_reader__
():
def
__get_reader__
():
...
@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader):
...
@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader):
def
reset
():
def
reset
():
return
__get_reader__
().
reset
()
return
__get_reader__
().
reset
()
def
start
():
return
__get_reader__
().
start
()
reader
.
reset
=
reset
reader
.
reset
=
reset
reader
.
start
=
start
reader
.
stop_gradient
=
True
reader
.
stop_gradient
=
True
reader
.
persistable
=
True
reader
.
persistable
=
True
return
reader
return
reader
...
@@ -44,7 +50,12 @@ def _copy_reader_var_(block, var):
...
@@ -44,7 +50,12 @@ def _copy_reader_var_(block, var):
return
new_var
return
new_var
def
ctr_reader
(
feed_data
,
def
ctr_reader
(
feed_dict
,
file_type
,
# gzip or plain
file_format
,
# csv or svm
dense_slot_indexs
,
sparse_slot_indexs
,
capacity
,
capacity
,
thread_num
,
thread_num
,
batch_size
,
batch_size
,
...
@@ -99,12 +110,22 @@ def ctr_reader(feed_data,
...
@@ -99,12 +110,22 @@ def ctr_reader(feed_data,
inputs
=
{
'blocking_queue'
:
[
queue_name
]},
inputs
=
{
'blocking_queue'
:
[
queue_name
]},
outputs
=
{
'Out'
:
[
reader_var
]},
outputs
=
{
'Out'
:
[
reader_var
]},
attrs
=
{
attrs
=
{
'use_data_config'
:
False
,
'thread_num'
:
thread_num
,
'thread_num'
:
thread_num
,
'batch_size'
:
batch_size
,
'batch_size'
:
batch_size
,
'file_list'
:
file_list
,
'file_list'
:
file_list
,
'slots'
:
slots
,
'file_type'
:
file_type
,
'file_format'
:
file_format
,
'dense_slot_index'
:
dense_slot_indexs
,
'sparse_slot_index'
:
sparse_slot_indexs
,
'sparse_slots'
:
slots
,
'ranks'
:
[],
'lod_levels'
:
[],
'shape_concat'
:
[]
})
})
dtypes
=
[
data
.
dtype
for
data
in
feed_dict
]
reader_var
.
desc
.
set_dtypes
(
dtypes
)
reader_var
.
persistable
=
True
reader_var
.
persistable
=
True
main_prog_reader_var
=
_copy_reader_var_
(
main_prog_reader_var
=
_copy_reader_var_
(
...
@@ -118,6 +139,9 @@ def ctr_reader(feed_data,
...
@@ -118,6 +139,9 @@ def ctr_reader(feed_data,
main_blk
=
default_main_program
().
current_block
()
main_blk
=
default_main_program
().
current_block
()
main_blk
.
append_op
(
main_blk
.
append_op
(
type
=
'read'
,
inputs
=
{
'Reader'
:
[
reader
]},
outputs
=
{
'Out'
:
feed_data
})
type
=
'read'
,
inputs
=
{
'Reader'
:
[
reader
]},
attrs
=
{
'infer_out'
:
False
},
outputs
=
{
'Out'
:
feed_dict
})
return
reader
return
reader
python/setup.py.in
浏览文件 @
daba57f7
...
@@ -107,6 +107,7 @@ packages=['paddle',
...
@@ -107,6 +107,7 @@ packages=['paddle',
'paddle.fluid.contrib',
'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.reader',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
'paddle.fluid.transpiler.details']
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录