Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
48f213e5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
48f213e5
编写于
3月 14, 2018
作者:
Y
Yu Yang
提交者:
GitHub
3月 14, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #8991 from reyoung/feature/shuffle_reader
Feature/shuffle reader
上级
881c5227
127b371e
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
270 addition
and
127 deletion
+270
-127
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+2
-18
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+15
-7
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+20
-31
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+88
-23
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+3
-2
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+12
-9
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+44
-31
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+5
-0
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+26
-0
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+30
-1
python/paddle/fluid/recordio_writer.py
python/paddle/fluid/recordio_writer.py
+3
-0
python/paddle/fluid/tests/unittests/test_recordio_reader.py
python/paddle/fluid/tests/unittests/test_recordio_reader.py
+22
-5
未找到文件。
paddle/fluid/framework/operator.cc
浏览文件 @
48f213e5
...
...
@@ -445,15 +445,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
ReaderHolder
>
())
{
return
var
->
Get
<
ReaderHolder
>
().
shapes
();
}
else
{
PADDLE_THROW
(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s."
,
name
,
var
->
Type
().
name
());
}
PADDLE_THROW
(
"Only compile time support this method"
);
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
...
...
@@ -470,15 +462,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
ReaderHolder
>
())
{
var
->
GetMutable
<
ReaderHolder
>
()
->
set_shapes
(
dims
);
}
else
{
PADDLE_THROW
(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s."
,
name
,
var
->
Type
().
name
());
}
PADDLE_THROW
(
"Only compile time support this method"
);
}
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
override
{
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
48f213e5
...
...
@@ -16,14 +16,22 @@
namespace
paddle
{
namespace
framework
{
ReaderBase
::~
ReaderBase
()
{}
DDim
ReaderBase
::
shape
(
size_t
idx
)
const
{
PADDLE_ENFORCE_LT
(
idx
,
shapes_
.
size
(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements."
,
idx
,
shapes_
.
size
());
return
shapes_
[
idx
];
}
FileReader
::
FileReader
(
const
std
::
vector
<
DDim
>
&
dims
)
:
dims_
(
dims
)
{}
void
FileReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
ReadNextImpl
(
out
);
PADDLE_ENFORCE_EQ
(
out
->
size
(),
dims_
.
size
());
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
auto
&
actual
=
out
->
at
(
i
).
dims
();
auto
&
expect
=
dims_
[
i
];
PADDLE_ENFORCE_EQ
(
actual
.
size
(),
expect
.
size
());
for
(
int
j
=
0
;
j
<
actual
.
size
();
++
j
)
{
PADDLE_ENFORCE
(
actual
[
i
]
==
expect
[
i
]
||
expect
[
i
]
==
-
1
);
}
}
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
48f213e5
...
...
@@ -16,40 +16,29 @@
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace
paddle
{
namespace
framework
{
class
ReaderBase
{
public:
explicit
ReaderBase
(
const
std
::
vector
<
DDim
>&
shapes
)
:
shapes_
(
shapes
)
{
PADDLE_ENFORCE
(
!
shapes_
.
empty
());
}
virtual
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
virtual
void
ReInit
()
=
0
;
DDim
shape
(
size_t
idx
)
const
;
std
::
vector
<
DDim
>
shapes
()
const
{
return
shapes_
;
}
void
set_shapes
(
const
std
::
vector
<
DDim
>&
shapes
)
{
shapes_
=
shapes
;
}
virtual
bool
HasNext
()
const
=
0
;
virtual
~
ReaderBase
()
{}
protected:
std
::
vector
<
DDim
>
shapes_
;
};
class
FileReader
:
public
ReaderBase
{
public:
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
shapes
)
:
ReaderBase
(
shapes
)
{}
virtual
~
ReaderBase
();
};
class
DecoratedReader
:
public
ReaderBase
{
public:
explicit
DecoratedReader
(
ReaderBase
*
reader
)
:
ReaderBase
(
reader
->
shapes
()),
reader_
(
reader
)
{
explicit
DecoratedReader
(
ReaderBase
*
reader
)
:
ReaderBase
(),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
...
...
@@ -61,6 +50,19 @@ class DecoratedReader : public ReaderBase {
ReaderBase
*
reader_
;
};
class
FileReader
:
public
ReaderBase
{
public:
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
dims
);
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
override
;
protected:
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
private:
std
::
vector
<
DDim
>
dims_
;
};
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class
ReaderHolder
{
...
...
@@ -78,19 +80,6 @@ class ReaderHolder {
reader_
->
ReInit
();
}
DDim
shape
(
size_t
idx
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
return
reader_
->
shape
(
idx
);
}
std
::
vector
<
DDim
>
shapes
()
const
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
return
reader_
->
shapes
();
}
void
set_shapes
(
const
std
::
vector
<
DDim
>&
shapes
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
set_shapes
(
shapes
);
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
private:
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
48f213e5
...
...
@@ -24,11 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
class
DoubleBufferReader
:
public
framework
::
DecoratedReader
{
public:
explicit
DoubleBufferReader
(
ReaderBase
*
reader
)
:
DecoratedReader
(
reader
),
buffer_
(
framework
::
MakeChannel
<
std
::
vector
<
framework
::
LoDTensor
>>
(
kDoubleBufferSize
))
{
std
::
thread
prefetch
(
&
DoubleBufferReader
::
PrefetchThreadFunc
,
this
);
struct
Item
{
Item
()
:
ctx_
(
nullptr
)
{}
std
::
vector
<
framework
::
LoDTensor
>
payloads_
;
platform
::
DeviceContext
*
ctx_
;
};
explicit
DoubleBufferReader
(
ReaderBase
*
reader
,
platform
::
Place
target_place
=
platform
::
CPUPlace
())
:
DecoratedReader
(
reader
),
place_
(
target_place
)
{
for
(
size_t
i
=
0
;
i
<
kDoubleBufferSize
;
++
i
)
{
if
(
platform
::
is_gpu_place
(
place_
))
{
#ifdef PADDLE_WITH_CUDA
ctxs_
.
emplace_back
(
new
platform
::
CUDADeviceContext
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
)));
#endif
}
}
start_thread
();
}
void
start_thread
()
{
buffer_
=
framework
::
MakeChannel
<
Item
>
(
kDoubleBufferSize
);
std
::
thread
prefetch
([
this
]
{
PrefetchThreadFunc
();
});
prefetch
.
detach
();
}
...
...
@@ -42,7 +62,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
private:
void
PrefetchThreadFunc
();
framework
::
Channel
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
framework
::
Channel
<
Item
>*
buffer_
;
platform
::
Place
place_
;
std
::
vector
<
std
::
unique_ptr
<
platform
::
DeviceContext
>>
ctxs_
;
mutable
Item
local_buffer_
;
};
class
CreateDoubleBufferReaderOp
:
public
framework
::
OperatorBase
{
...
...
@@ -56,7 +79,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
DoubleBufferReader
(
underlying_reader
.
Get
()));
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
platform
::
Place
place
;
if
(
place_str
==
"CPU"
)
{
place
=
platform
::
CPUPlace
();
}
else
{
std
::
istringstream
sin
(
place_str
);
sin
.
seekg
(
std
::
string
(
"CUDA:"
).
size
(),
std
::
ios
::
beg
);
size_t
num
;
sin
>>
num
;
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
}
out
->
Reset
(
new
DoubleBufferReader
(
underlying_reader
.
Get
(),
place
));
}
};
...
...
@@ -71,44 +107,73 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
It launches another thread to execute the 'underlying reader' asynchronously,
which prevents reading process from blocking subsequent training.
)DOC"
);
std
::
unordered_set
<
std
::
string
>
enum_range
;
constexpr
size_t
kMaxCUDADevs
=
128
;
for
(
size_t
i
=
0
;
i
<
kMaxCUDADevs
;
++
i
)
{
enum_range
.
insert
(
string
::
Sprintf
(
"CUDA:%d"
,
i
));
}
enum_range
.
insert
(
"CPU"
);
AddAttr
<
std
::
string
>
(
"place"
,
"The double buffer place, default is CPU"
)
.
SetDefault
(
"CPU"
)
.
InEnum
({
enum_range
});
}
};
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
out
->
clear
();
buffer_
->
Receive
(
out
);
if
(
local_buffer_
.
payloads_
.
empty
())
{
buffer_
->
Receive
(
&
local_buffer_
);
}
*
out
=
local_buffer_
.
payloads_
;
local_buffer_
.
payloads_
.
clear
();
if
(
local_buffer_
.
ctx_
)
{
local_buffer_
.
ctx_
->
Wait
();
}
}
void
DoubleBufferReader
::
ReInit
()
{
reader_
->
ReInit
();
buffer_
->
Close
();
// The existing prefetch thread will terminate for the buffer_ is closed.
buffer_
=
framework
::
MakeChannel
<
std
::
vector
<
framework
::
LoDTensor
>>
(
kDoubleBufferSize
);
std
::
thread
prefetch
(
&
DoubleBufferReader
::
PrefetchThreadFunc
,
this
);
prefetch
.
detach
();
start_thread
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
while
(
true
)
{
std
::
vector
<
framework
::
LoDTensor
>
batch
;
reader_
->
ReadNext
(
&
batch
);
if
(
batch
.
empty
())
{
// EOF
buffer_
->
Close
();
VLOG
(
5
)
<<
"Reached the end of the file. The prefetch thread terminates."
;
break
;
size_t
gpu_ctx_offset
=
0
;
while
(
reader_
->
HasNext
())
{
Item
batch
;
reader_
->
ReadNext
(
&
batch
.
payloads_
);
if
(
platform
::
is_gpu_place
(
place_
))
{
std
::
vector
<
framework
::
LoDTensor
>
gpu_batch
;
auto
&
gpu_ctx
=
this
->
ctxs_
[
gpu_ctx_offset
++
];
gpu_ctx_offset
%=
this
->
ctxs_
.
size
();
gpu_batch
.
resize
(
batch
.
payloads_
.
size
());
for
(
size_t
i
=
0
;
i
<
batch
.
payloads_
.
size
();
++
i
)
{
framework
::
TensorCopy
(
batch
.
payloads_
[
i
],
place_
,
*
gpu_ctx
,
&
gpu_batch
[
i
]);
gpu_batch
[
i
].
set_lod
(
batch
.
payloads_
[
i
].
lod
());
}
batch
.
ctx_
=
gpu_ctx
.
get
();
std
::
swap
(
gpu_batch
,
batch
.
payloads_
);
}
if
(
!
buffer_
->
Send
(
&
batch
))
{
VLOG
(
5
)
<<
"WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates."
;
break
;
}
}
buffer_
->
Close
();
}
bool
DoubleBufferReader
::
HasNext
()
const
{
PADDLE_THROW
(
"Not Implemented"
);
}
bool
DoubleBufferReader
::
HasNext
()
const
{
if
(
local_buffer_
.
payloads_
.
empty
())
{
bool
ok
=
buffer_
->
Receive
(
&
local_buffer_
);
return
ok
;
}
else
{
return
true
;
}
}
}
// namespace reader
}
// namespace operators
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
48f213e5
...
...
@@ -19,11 +19,11 @@ namespace operators {
namespace
reader
{
template
<
typename
T
>
class
RandomDataGenerator
:
public
framework
::
FileReader
{
class
RandomDataGenerator
:
public
framework
::
ReaderBase
{
public:
RandomDataGenerator
(
const
std
::
vector
<
framework
::
DDim
>&
shapes
,
float
min
,
float
max
)
:
FileReader
(
shapes
),
min_
(
min
),
max_
(
max
)
{
:
framework
::
ReaderBase
(),
min_
(
min
),
max_
(
max
),
shapes_
(
shapes
)
{
PADDLE_ENFORCE_LE
(
min
,
max
,
"'min' shouldn't be greater than 'max'.(%f vs %f)"
,
min
,
max
);
unsigned
int
seed
=
std
::
random_device
()();
...
...
@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
float
max_
;
std
::
minstd_rand
engine_
;
std
::
uniform_real_distribution
<
float
>
dist_
;
std
::
vector
<
framework
::
DDim
>
shapes_
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
48f213e5
...
...
@@ -20,21 +20,22 @@ namespace operators {
namespace
reader
{
class
RecordIOFileReader
:
public
framework
::
FileReader
{
public:
RecordIOFileReader
(
const
std
::
string
&
filename
,
const
std
::
vector
<
framework
::
DDim
>&
shape
s
)
:
FileReader
(
shape
s
),
explicit
RecordIOFileReader
(
const
std
::
string
&
filename
,
const
std
::
vector
<
framework
::
DDim
>&
dim
s
)
:
FileReader
(
dim
s
),
scanner_
(
filename
),
dev_ctx_
(
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()))
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
}
bool
HasNext
()
const
override
{
return
scanner_
.
HasNext
();
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
}
private:
recordio
::
Scanner
scanner_
;
const
platform
::
DeviceContext
&
dev_ctx_
;
...
...
@@ -54,12 +55,12 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
int
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RecordIOFileReader
(
filename
,
shapes
));
out
->
Reset
(
new
RecordIOFileReader
(
filename
,
RestoreShapes
(
shape_concat
,
ranks
)));
}
};
...
...
@@ -85,3 +86,5 @@ namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR
(
create_recordio_file_reader
,
reader
::
CreateRecordIOReaderOp
,
reader
::
CreateRecordIOReaderOpMaker
);
REGISTER_FILE_READER
(
recordio
,
reader
::
RecordIOFileReader
);
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
48f213e5
...
...
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <random>
#include "glog/logging.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
...
...
@@ -20,43 +23,53 @@ namespace reader {
class
ShuffleReader
:
public
framework
::
DecoratedReader
{
public:
ShuffleReader
(
ReaderBase
*
reader
,
int
buffer_size
)
:
DecoratedReader
(
reader
),
buffer_size_
(
buffer_size
),
iteration_pos_
(
0
)
{
buffer_
.
reserve
(
buffer_size
);
ShuffleReader
(
ReaderBase
*
reader
,
size_t
buffer_size
,
size_t
seed
=
0
)
:
DecoratedReader
(
reader
),
buffer_size_
(
buffer_size
),
seed_
(
seed
)
{
VLOG
(
10
)
<<
"Create shuffle reader of "
<<
reader_
;
if
(
seed_
==
0
)
{
std
::
random_device
device
;
seed_
=
device
();
}
ReadIntoBuffers
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
ReadIntoBuffers
();
}
*
out
=
buffer_
[
iteration_pos_
++
];
}
private:
int
buffer_size_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
size_t
iteration_pos_
;
};
bool
HasNext
()
const
override
{
return
iteration_pos_
<
buffer_
.
size
()
||
reader_
->
HasNext
();
}
void
ShuffleReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
// Reload buffer with new data
private:
void
ReadIntoBuffers
()
{
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
for
(
int
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
buffer_
.
push_back
(
std
::
vector
<
framework
::
LoDTensor
>
());
reader_
->
ReadNext
(
&
buffer_
.
back
());
if
(
buffer_
.
back
().
empty
())
{
buffer_
.
pop_back
();
iteration_pos_
=
0
;
PADDLE_ENFORCE
(
reader_
->
HasNext
());
for
(
size_t
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
!
reader_
->
HasNext
())
{
break
;
}
buffer_
.
emplace_back
();
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimize.
s
td
::
random_shuffle
(
buffer_
.
begin
(),
buffer_
.
end
())
;
iteration_pos_
=
0
;
std
::
mt19937
g
(
seed_
);
std
::
shuffle
(
buffer_
.
begin
(),
buffer_
.
end
(),
g
);
s
eed_
=
g
();
// update seed_
;
VLOG
(
10
)
<<
"random buffer size = "
<<
buffer_
.
size
()
;
}
out
->
clear
();
if
(
!
buffer_
.
empty
())
{
std
::
swap
(
*
out
,
buffer_
[
iteration_pos_
++
]);
}
// if buffer_ is empty, the 'out' will return as an empty vector.
}
size_t
buffer_size_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
size_t
iteration_pos_
;
size_t
seed_
;
};
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
public:
...
...
@@ -67,10 +80,10 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"buffer_size"
)));
auto
&
var
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)));
var
.
GetMutable
<
framework
::
ReaderHolder
>
()
->
Reset
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
)
)));
}
};
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
48f213e5
...
...
@@ -31,6 +31,11 @@ std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
return
res
;
}
std
::
unordered_map
<
std
::
string
,
FileReaderCreator
>&
FileReaderRegistry
()
{
static
std
::
unordered_map
<
std
::
string
,
FileReaderCreator
>
regs
;
return
regs
;
}
FileReaderMakerBase
::
FileReaderMakerBase
(
framework
::
OpProtoAndCheckerMaker
::
OpProto
*
op_proto
,
framework
::
OpAttrChecker
*
op_checker
)
...
...
paddle/fluid/operators/reader/reader_op_registry.h
浏览文件 @
48f213e5
...
...
@@ -21,6 +21,20 @@ namespace paddle {
namespace
operators
{
namespace
reader
{
using
FileReaderCreator
=
std
::
function
<
framework
::
ReaderBase
*
(
const
std
::
string
&
,
const
std
::
vector
<
framework
::
DDim
>&
)
>
;
std
::
unordered_map
<
std
::
string
,
FileReaderCreator
>&
FileReaderRegistry
();
template
<
typename
Reader
>
int
RegisterFileReader
(
const
std
::
string
&
filetype
)
{
FileReaderRegistry
()[
filetype
]
=
[](
const
std
::
string
&
fn
,
const
std
::
vector
<
paddle
::
framework
::
DDim
>&
dim
)
{
return
new
Reader
(
fn
,
dim
);
};
return
0
;
}
extern
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
);
...
...
@@ -73,3 +87,15 @@ class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
paddle::operators::reader::DecoratedReaderInferShape, \
paddle::framework::EmptyGradOpMaker, \
paddle::operators::reader::DecoratedReaderInferVarType)
#define REGISTER_FILE_READER(_filetype, _reader) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
_reg_file_reader_##_filetype, \
"Must use REGISTER_FILE_READER in global namespace"); \
int TouchFileReader##_filetype() { return 0; } \
int _reg_file_reader_entry_##filetype = \
paddle::operators::reader::RegisterFileReader<_reader>(#_filetype)
#define USE_FILE_READER(filetype) \
extern int TouchFileReader##filetype(); \
static int _use_##filetype = TouchFileReader##filetype()
python/paddle/fluid/layers/io.py
浏览文件 @
48f213e5
...
...
@@ -21,7 +21,7 @@ from ..executor import global_scope
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'read_file'
'read_file'
,
'create_shuffle_reader'
,
'create_double_buffer_reader'
]
...
...
@@ -245,6 +245,8 @@ def monkey_patch_reader_methods(reader):
reader
.
eof
=
eof
reader
.
reset
=
reset
reader
.
stop_gradient
=
True
reader
.
persistable
=
True
return
reader
...
...
@@ -285,6 +287,33 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var
)
def
__create_decorated_reader__
(
op_type
,
reader
,
attrs
):
var_name
=
unique_name
(
op_type
)
startup_blk
=
default_startup_program
().
current_block
()
startup_var
=
startup_blk
.
create_var
(
name
=
var_name
)
startup_blk
.
append_op
(
type
=
op_type
,
inputs
=
{
'UnderlyingReader'
:
reader
},
outputs
=
{
'Out'
:
[
startup_var
]},
attrs
=
attrs
)
startup_var
.
persistable
=
True
return
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
def
create_shuffle_reader
(
reader
,
buffer_size
):
return
__create_decorated_reader__
(
'create_shuffle_reader'
,
reader
,
{
'buffer_size'
:
int
(
buffer_size
)})
def
create_double_buffer_reader
(
reader
,
place
=
None
):
attrs
=
dict
()
if
place
is
not
None
:
attrs
[
'place'
]
=
str
(
place
).
upper
()
return
__create_decorated_reader__
(
'create_double_buffer_reader'
,
reader
,
attrs
)
def
read_file
(
file_obj
):
helper
=
LayerHelper
(
'read_file'
)
out
=
[
...
...
python/paddle/fluid/recordio_writer.py
浏览文件 @
48f213e5
...
...
@@ -36,6 +36,7 @@ def convert_reader_to_recordio_file(
feed_order
=
None
):
if
feed_order
is
None
:
feed_order
=
feeder
.
feed_names
counter
=
0
with
create_recordio_writer
(
filename
,
compressor
,
max_num_records
)
as
writer
:
for
batch
in
reader_creator
():
...
...
@@ -43,3 +44,5 @@ def convert_reader_to_recordio_file(
for
each
in
feed_order
:
writer
.
append_tensor
(
res
[
each
])
writer
.
complete_append_tensor
()
counter
+=
1
return
counter
python/paddle/fluid/tests/unittests/test_recordio_reader.py
浏览文件 @
48f213e5
...
...
@@ -13,9 +13,10 @@
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
paddle.v2.dataset.mnist
as
mnist
import
paddle.v2
as
paddle
import
paddle.v2.dataset.mnist
as
mnist
class
TestRecordIO
(
unittest
.
TestCase
):
...
...
@@ -31,10 +32,10 @@ class TestRecordIO(unittest.TestCase):
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
),
],
place
=
fluid
.
CPUPlace
())
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
self
.
num_batches
=
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
'./mnist.recordio'
,
reader
,
feeder
)
def
test_main
(
self
):
def
test_main
(
self
,
decorator_callback
=
None
):
# use new program
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
data_file
=
fluid
.
layers
.
open_recordio_file
(
...
...
@@ -42,6 +43,8 @@ class TestRecordIO(unittest.TestCase):
shapes
=
[[
-
1
,
784
],
[
-
1
,
1
]],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
if
decorator_callback
is
not
None
:
data_file
=
decorator_callback
(
data_file
)
img
,
label
=
fluid
.
layers
.
read_file
(
data_file
)
hidden
=
fluid
.
layers
.
fc
(
input
=
img
,
size
=
100
,
act
=
'tanh'
)
...
...
@@ -51,14 +54,28 @@ class TestRecordIO(unittest.TestCase):
fluid
.
optimizer
.
Adam
(
learning_rate
=
1e-3
).
minimize
(
avg_loss
)
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
if
fluid
.
core
.
is_compiled_with_cuda
():
place
=
fluid
.
CUDAPlace
(
0
)
else
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
avg_loss_np
=
[]
# train a pass
batch_id
=
0
while
not
data_file
.
eof
():
tmp
,
=
exe
.
run
(
fetch_list
=
[
avg_loss
])
avg_loss_np
.
append
(
tmp
)
batch_id
+=
1
data_file
.
reset
()
self
.
assertEqual
(
batch_id
,
self
.
num_batches
)
self
.
assertLess
(
avg_loss_np
[
-
1
],
avg_loss_np
[
0
])
def
test_shuffle_reader
(
self
):
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
create_shuffle_reader
(
reader
,
buffer_size
=
200
))
def
test_double_buffer_reader
(
self
):
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
create_double_buffer_reader
(
reader
,
place
=
'cuda:0'
if
fluid
.
core
.
is_compiled_with_cuda
()
else
'cpu'
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录