Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
95ba4bd2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
95ba4bd2
编写于
10月 29, 2019
作者:
H
Huihuang Zheng
提交者:
GitHub
10月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add shape and type check at read_op (#20754)
上级
bb8d7783
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
220 addition
and
16 deletion
+220
-16
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+58
-2
paddle/fluid/framework/reader_test.cc
paddle/fluid/framework/reader_test.cc
+20
-1
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+34
-1
paddle/fluid/operators/reader/py_reader.cc
paddle/fluid/operators/reader/py_reader.cc
+6
-2
paddle/fluid/operators/reader/py_reader.h
paddle/fluid/operators/reader/py_reader.h
+5
-1
paddle/fluid/operators/reader/read_op.cc
paddle/fluid/operators/reader/read_op.cc
+45
-2
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+15
-0
paddle/fluid/pybind/reader_py.cc
paddle/fluid/pybind/reader_py.cc
+14
-2
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+6
-1
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+16
-2
python/paddle/fluid/tests/unittests/test_py_reader_pin_memory.py
...paddle/fluid/tests/unittests/test_py_reader_pin_memory.py
+1
-2
未找到文件。
paddle/fluid/framework/reader.h
浏览文件 @
95ba4bd2
...
...
@@ -20,6 +20,7 @@
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
...
...
@@ -28,6 +29,20 @@ namespace framework {
class
ReaderBase
{
public:
explicit
ReaderBase
(
const
std
::
vector
<
DDim
>&
shapes
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
var_types
,
const
std
::
vector
<
bool
>&
need_check_feed
)
:
shapes_
(
shapes
),
var_types_
(
var_types
),
need_check_feed_
(
need_check_feed
)
{
PADDLE_ENFORCE_EQ
(
shapes_
.
size
(),
need_check_feed_
.
size
(),
"Construct ReaderBase with mismatched sizes of shapes "
"and need_check_feed"
);
PADDLE_ENFORCE_EQ
(
var_types_
.
size
(),
need_check_feed_
.
size
(),
"Construct ReaderBase with mismatched sizes of var_types "
"and need_check_feed"
);
}
virtual
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
);
virtual
void
Shutdown
();
...
...
@@ -38,6 +53,18 @@ class ReaderBase {
// they are readers just before read op.
std
::
unordered_set
<
ReaderBase
*>
GetEndPoints
();
// Returns the shapes of the feeded variables
const
std
::
vector
<
DDim
>&
Shapes
()
const
{
return
shapes_
;
}
// Returns the dtypes of the feeded variables
const
std
::
vector
<
proto
::
VarType
::
Type
>&
VarTypes
()
const
{
return
var_types_
;
}
// For Backward compatibility, old fluid.layers.data doesn't check shape.
// This function returns whether you have the check shape for this Reader.
const
std
::
vector
<
bool
>&
NeedCheckFeed
()
const
{
return
need_check_feed_
;
}
virtual
~
ReaderBase
();
protected:
...
...
@@ -53,6 +80,17 @@ class ReaderBase {
mutable
std
::
mutex
mu_
;
// The shapes of the feeded variables.
std
::
vector
<
DDim
>
shapes_
;
// The dtypes of the feeded variables.
std
::
vector
<
proto
::
VarType
::
Type
>
var_types_
;
// Whether to check the shape and dtype of feeded variables.
// For Backward compatibility, variables created by old API fluid.layers.data
// doesn't check shape but fluid.data checks.
std
::
vector
<
bool
>
need_check_feed_
;
private:
friend
class
DecoratedReader
;
// These methods can be only invoked inside DecoratedReader to record the
...
...
@@ -67,7 +105,9 @@ class DecoratedReader : public ReaderBase,
public
std
::
enable_shared_from_this
<
DecoratedReader
>
{
public:
explicit
DecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
ReaderBase
(),
reader_
(
reader
)
{
:
ReaderBase
(
reader
->
Shapes
(),
reader
->
VarTypes
(),
reader
->
NeedCheckFeed
()),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
...
...
@@ -89,7 +129,13 @@ class DecoratedReader : public ReaderBase,
};
// FileReader is just a conceptual class.
class
FileReader
:
public
ReaderBase
{};
class
FileReader
:
public
ReaderBase
{
public:
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
shapes
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
var_types
,
const
std
::
vector
<
bool
>&
need_check_feed
)
:
ReaderBase
(
shapes
,
var_types
,
need_check_feed
)
{}
};
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
...
...
@@ -134,6 +180,16 @@ class ReaderHolder {
reader_
->
Start
();
}
const
std
::
vector
<
DDim
>&
Shapes
()
const
{
return
reader_
->
Shapes
();
}
const
std
::
vector
<
proto
::
VarType
::
Type
>&
VarTypes
()
const
{
return
reader_
->
VarTypes
();
}
const
std
::
vector
<
bool
>&
NeedCheckFeed
()
const
{
return
reader_
->
NeedCheckFeed
();
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
private:
...
...
paddle/fluid/framework/reader_test.cc
浏览文件 @
95ba4bd2
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ddim.h"
class
StubDecoratedReader
:
public
paddle
::
framework
::
DecoratedReader
{
public:
...
...
@@ -26,11 +27,23 @@ class StubDecoratedReader : public paddle::framework::DecoratedReader {
class
StubRootReader
:
public
paddle
::
framework
::
ReaderBase
{
public:
explicit
StubRootReader
(
const
std
::
vector
<
paddle
::
framework
::
DDim
>
&
dims
,
const
std
::
vector
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
&
var_types
,
const
std
::
vector
<
bool
>
&
need_check_feed
)
:
paddle
::
framework
::
ReaderBase
(
dims
,
var_types
,
need_check_feed
)
{}
void
ReadNextImpl
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
};
TEST
(
READER
,
decorate_chain
)
{
auto
root
=
std
::
make_shared
<
StubRootReader
>
();
paddle
::
framework
::
proto
::
VarType
::
Type
dtype
=
paddle
::
framework
::
proto
::
VarType
::
FP32
;
paddle
::
framework
::
DDim
dim
=
paddle
::
framework
::
make_ddim
({
5
,
7
});
std
::
vector
<
paddle
::
framework
::
DDim
>
init_dims
(
4
,
dim
);
std
::
vector
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
init_types
(
4
,
dtype
);
std
::
vector
<
bool
>
init_need_check
(
4
,
true
);
auto
root
=
std
::
make_shared
<
StubRootReader
>
(
init_dims
,
init_types
,
init_need_check
);
auto
end_point1
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
auto
end_point2
=
...
...
@@ -49,4 +62,10 @@ TEST(READER, decorate_chain) {
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
3U
);
}
{
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
2U
);
}
{
ASSERT_EQ
(
end_point1
->
Shapes
(),
init_dims
);
ASSERT_EQ
(
end_point1
->
VarTypes
(),
init_types
);
ASSERT_EQ
(
end_point1
->
NeedCheckFeed
(),
init_need_check
);
}
}
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
95ba4bd2
...
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
...
...
@@ -39,7 +41,38 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto
*
queue_holder
=
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
out
->
Reset
(
std
::
make_shared
<
PyReader
>
(
queue_holder
->
GetQueue
()));
/* Coverting shape_concat and ranks into DDim of each data.
shape_concat and ranks are shapes and shape ranks of each data.E.g.
shape_concat = [2,3,4,5,6], ranks = [3,2] means two data whose shapes are
[2,3,4] and [5,6] respectively. */
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
int
shape_start_index
=
0
;
std
::
vector
<
framework
::
DDim
>
dims
;
for
(
size_t
i
=
0
;
i
<
ranks
.
size
();
++
i
)
{
int
shape_end_index
=
shape_start_index
+
ranks
[
i
];
auto
shape
=
std
::
vector
<
int
>
(
shape_concat
.
begin
()
+
shape_start_index
,
shape_concat
.
begin
()
+
shape_end_index
);
dims
.
push_back
(
framework
::
make_ddim
(
shape
));
shape_start_index
=
shape_end_index
;
}
// Converts VarType from int to enum
auto
&
dtype_int
=
Attr
<
std
::
vector
<
int
>>
(
"dtypes"
);
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
var_types
;
for
(
size_t
i
=
0
;
i
<
dtype_int
.
size
();
++
i
)
{
var_types
.
push_back
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
dtype_int
[
i
]));
}
// Converts need_check_feed from int to bool
auto
&
need_check_feed_int
=
Attr
<
std
::
vector
<
int
>>
(
"need_check_feed"
);
std
::
vector
<
bool
>
need_check_feed
;
for
(
size_t
i
=
0
;
i
<
need_check_feed_int
.
size
();
++
i
)
{
need_check_feed
.
push_back
(
static_cast
<
bool
>
(
need_check_feed_int
[
i
]));
}
out
->
Reset
(
std
::
make_shared
<
PyReader
>
(
queue_holder
->
GetQueue
(),
dims
,
var_types
,
need_check_feed
));
}
};
...
...
paddle/fluid/operators/reader/py_reader.cc
浏览文件 @
95ba4bd2
...
...
@@ -19,8 +19,12 @@ namespace paddle {
namespace
operators
{
namespace
reader
{
PyReader
::
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
)
:
framework
::
FileReader
()
{
PyReader
::
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
var_types
,
const
std
::
vector
<
bool
>&
need_check_feed
)
:
framework
::
FileReader
(
dims
,
var_types
,
need_check_feed
)
{
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
queue_
=
queue
;
}
...
...
paddle/fluid/operators/reader/py_reader.h
浏览文件 @
95ba4bd2
...
...
@@ -26,7 +26,11 @@ namespace reader {
class
PyReader
:
public
framework
::
FileReader
{
public:
explicit
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
);
explicit
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
var_types
,
const
std
::
vector
<
bool
>&
need_check_feed
);
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
...
...
paddle/fluid/operators/reader/read_op.cc
浏览文件 @
95ba4bd2
//
Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
...
...
@@ -20,6 +21,26 @@
namespace
paddle
{
namespace
operators
{
// Returns true if the two dimensions are compatible.
// A dimension is compatible with the other if:
// 1. The length of the dimensions are same.
// 2. Each non-negative number of the two dimentions are same.
// 3. For negative number in a dimention, it means unknown so it is compatible
// with any number.
bool
DimensionIsCompatibleWith
(
const
framework
::
DDim
&
first
,
const
framework
::
DDim
&
second
)
{
int
dim_size
=
first
.
size
();
if
(
dim_size
!=
second
.
size
())
{
return
false
;
}
for
(
int
i
=
0
;
i
<
dim_size
;
++
i
)
{
if
(
first
[
i
]
>=
0
&&
second
[
i
]
>=
0
&&
first
[
i
]
!=
second
[
i
])
{
return
false
;
}
}
return
true
;
}
class
ReadInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
...
...
@@ -89,10 +110,32 @@ class ReadOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"throw_eof_exp"
;
PADDLE_THROW_EOF
();
}
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
(),
"input size and output size of read_op do not match"
);
const
std
::
vector
<
framework
::
DDim
>&
shapes
=
reader
->
Shapes
();
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
var_types
=
reader
->
VarTypes
();
const
std
::
vector
<
bool
>&
need_check_feed
=
reader
->
NeedCheckFeed
();
PADDLE_ENFORCE_EQ
(
out_arg_names
.
size
(),
need_check_feed
.
size
(),
"output size of read_op and the number of feeded "
"variables of reader do not match"
);
for
(
size_t
i
=
0
;
i
<
out_arg_names
.
size
();
++
i
)
{
auto
*
out
=
scope
.
FindVar
(
out_arg_names
[
i
])
->
GetMutable
<
framework
::
LoDTensor
>
();
if
(
need_check_feed
[
i
])
{
auto
in_dims
=
ins
[
i
].
dims
();
PADDLE_ENFORCE_EQ
(
DimensionIsCompatibleWith
(
shapes
[
i
],
in_dims
),
true
,
"The feeded Variable %s should have dimensions = %d, "
"shape = [%s], but received feeded shape [%s]"
,
out_arg_names
[
i
],
shapes
[
i
].
size
(),
shapes
[
i
],
in_dims
);
PADDLE_ENFORCE_EQ
(
ins
[
i
].
type
(),
var_types
[
i
],
"The data type of feeded Variable %s must be %s, but received %s"
,
out_arg_names
[
i
],
var_types
[
i
],
ins
[
i
].
type
());
}
out
->
ShareDataWith
(
ins
[
i
]);
out
->
set_lod
(
ins
[
i
].
lod
());
}
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
95ba4bd2
...
...
@@ -50,6 +50,10 @@ void FileReaderMakerBase::Make() {
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."
);
AddAttr
<
std
::
vector
<
int
>>
(
"lod_levels"
,
"The LoD levels of each data."
);
AddAttr
<
std
::
vector
<
int
>>
(
"dtypes"
,
"The int value of enum dtypes of each data."
);
AddAttr
<
std
::
vector
<
int
>>
(
"need_check_feed"
,
"Whether to check shape and dtypes of input"
);
AddAttr
<
bool
>
(
"use_data_config"
,
"Use the config of all datas like shape_concat/ranks/lod_levels"
)
...
...
@@ -77,6 +81,17 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d)."
,
lod_levels
.
size
(),
shapes
.
size
());
const
auto
dtypes
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"dtypes"
);
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
shapes
.
size
(),
"The number of 'dtypes'(%d) doesn't match the number of 'shapes'(%d)."
,
dtypes
.
size
(),
shapes
.
size
());
const
auto
need_check_feed
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"need_check_feed"
);
PADDLE_ENFORCE_EQ
(
need_check_feed
.
size
(),
shapes
.
size
(),
"The number of 'need_check_feed'(%d) doesn't match the "
"number of 'shapes'(%d)."
,
need_check_feed
.
size
(),
shapes
.
size
());
framework
::
VarDesc
*
reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
reader
->
SetLoDLevels
(
lod_levels
);
...
...
paddle/fluid/pybind/reader_py.cc
浏览文件 @
95ba4bd2
...
...
@@ -20,6 +20,7 @@
#include <utility>
#include <vector>
#include "Python.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h"
...
...
@@ -40,12 +41,19 @@ class MultiDeviceFeedReader {
MultiDeviceFeedReader
(
const
std
::
shared_ptr
<
operators
::
reader
::
LoDTensorBlockingQueue
>
&
queue
,
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
std
::
vector
<
int
>>
&
shapes
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
dtypes
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
bool
use_double_buffer
)
:
queue_
(
queue
),
names_
(
names
),
pool_
(
new
::
ThreadPool
(
dst_places
.
size
()))
{
std
::
vector
<
framework
::
DDim
>
dims
;
for
(
auto
&
shape
:
shapes
)
{
dims
.
push_back
(
framework
::
make_ddim
(
shape
));
}
std
::
shared_ptr
<
framework
::
ReaderBase
>
reader
(
new
operators
::
reader
::
PyReader
(
queue
));
new
operators
::
reader
::
PyReader
(
queue
,
dims
,
dtypes
,
need_check_feed
));
readers_
.
reserve
(
dst_places
.
size
());
for
(
auto
&
p
:
dst_places
)
{
...
...
@@ -206,9 +214,13 @@ void BindReader(py::module *module) {
[](
const
std
::
shared_ptr
<
operators
::
reader
::
LoDTensorBlockingQueue
>
&
queue
,
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
std
::
vector
<
int
>>
&
shapes
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
dtypes
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
bool
use_double_buffer
)
{
return
new
MultiDeviceFeedReader
(
queue
,
names
,
dst_places
,
return
new
MultiDeviceFeedReader
(
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
use_double_buffer
);
},
py
::
return_value_policy
::
take_ownership
);
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
95ba4bd2
...
...
@@ -385,6 +385,7 @@ def _py_reader(capacity,
shape_concat
=
[]
ranks
=
[]
shapes
=
[]
need_check_feed
=
[]
for
feed_data
in
feed_list
:
dtypes
.
append
(
feed_data
.
dtype
)
...
...
@@ -392,8 +393,10 @@ def _py_reader(capacity,
ranks
.
append
(
len
(
feed_data
.
shape
))
shapes
.
append
(
feed_data
.
shape
)
lod_levels
.
append
(
feed_data
.
lod_level
)
need_check_feed
.
append
(
int
(
feed_data
.
desc
.
need_check_feed
()))
else
:
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
need_check_feed
=
[
0
for
dt
in
dtypes
]
shape_concat
=
[]
ranks
=
[]
...
...
@@ -403,7 +406,7 @@ def _py_reader(capacity,
if
lod_levels
is
None
:
lod_levels
=
[
0
]
*
len
(
shapes
)
dtype_int
=
[
int
(
t
)
for
t
in
dtypes
]
if
name
is
None
:
queue_name
=
unique_name
(
'lod_tensor_blocking_queue'
)
reader_name
=
unique_name
(
'create_py_reader'
)
...
...
@@ -425,6 +428,8 @@ def _py_reader(capacity,
attrs
=
{
'shape_concat'
:
shape_concat
,
'lod_levels'
:
lod_levels
,
'dtypes'
:
dtype_int
,
'need_check_feed'
:
need_check_feed
,
'ranks'
:
ranks
})
...
...
python/paddle/fluid/reader.py
浏览文件 @
95ba4bd2
...
...
@@ -331,7 +331,7 @@ class GeneratorLoader(DataLoaderBase):
self
.
_init_non_iterable
()
def
_wait_thread_ends
(
self
):
# Get self._thread first to prevent data race, because __thread_main__
# Get self._thread first to prevent data race, because __thread_main__
# would set self._thread be None at the end
thread
=
self
.
_thread
if
thread
is
not
None
and
self
.
_iterable
:
...
...
@@ -342,12 +342,21 @@ class GeneratorLoader(DataLoaderBase):
self
.
_wait_thread_ends
()
if
in_dygraph_mode
():
self
.
_var_names
=
[]
self
.
_shapes
=
[]
self
.
_dtypes
=
[]
self
.
_need_check_feed
=
[]
else
:
self
.
_var_names
=
[
v
.
name
for
v
in
self
.
_feed_list
]
self
.
_shapes
=
[
v
.
shape
for
v
in
self
.
_feed_list
]
self
.
_dtypes
=
[
v
.
dtype
for
v
in
self
.
_feed_list
]
self
.
_need_check_feed
=
[
v
.
desc
.
need_check_feed
()
for
v
in
self
.
_feed_list
]
self
.
_queue
=
core
.
init_lod_tensor_blocking_queue
(
core
.
Variable
(),
self
.
_capacity
)
self
.
_reader
=
core
.
create_py_reader
(
self
.
queue
,
self
.
_var_names
,
self
.
_places
,
self
.
_use_double_buffer
)
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
)
def
_init_non_iterable
(
self
):
lod_levels
=
[]
...
...
@@ -355,6 +364,7 @@ class GeneratorLoader(DataLoaderBase):
shape_concat
=
[]
ranks
=
[]
shapes
=
[]
need_check_feed
=
[]
for
feed_data
in
self
.
_feed_list
:
dtypes
.
append
(
feed_data
.
dtype
)
...
...
@@ -362,6 +372,7 @@ class GeneratorLoader(DataLoaderBase):
ranks
.
append
(
len
(
feed_data
.
shape
))
shapes
.
append
(
feed_data
.
shape
)
lod_levels
.
append
(
feed_data
.
lod_level
)
need_check_feed
.
append
(
int
(
feed_data
.
desc
.
need_check_feed
()))
queue_name
=
data_loader_unique_name_generator
(
'lod_tensor_blocking_queue'
)
...
...
@@ -374,6 +385,7 @@ class GeneratorLoader(DataLoaderBase):
startup_blk
=
default_startup_program
().
current_block
()
startup_var
=
startup_blk
.
create_var
(
name
=
reader_name
)
dtype_int
=
[
int
(
t
)
for
t
in
dtypes
]
startup_blk
.
append_op
(
type
=
'create_py_reader'
,
inputs
=
{
'blocking_queue'
:
[
queue_name
]},
...
...
@@ -381,6 +393,8 @@ class GeneratorLoader(DataLoaderBase):
attrs
=
{
'shape_concat'
:
shape_concat
,
'lod_levels'
:
lod_levels
,
'dtypes'
:
dtype_int
,
'need_check_feed'
:
need_check_feed
,
'ranks'
:
ranks
})
...
...
python/paddle/fluid/tests/unittests/test_py_reader_pin_memory.py
浏览文件 @
95ba4bd2
...
...
@@ -73,8 +73,7 @@ class TestPyReader(unittest.TestCase):
for
_
in
range
(
10
):
sample
=
np
.
random
.
uniform
(
low
=
0
,
high
=
1
,
size
=
[
3
,
2
,
1
]).
astype
(
"float32"
)
label
=
np
.
random
.
uniform
(
low
=
0
,
high
=
10
,
size
=
[
1
]).
astype
(
"int64"
)
label
=
np
.
random
.
randint
(
low
=
0
,
high
=
10
,
dtype
=
"int64"
)
self
.
inputs
.
append
((
sample
,
label
))
self
.
input_tensors
=
[]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录