Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
90084a25
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
90084a25
编写于
4月 11, 2018
作者:
F
fengjiayi
提交者:
GitHub
4月 11, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9743 from JiayiFeng/modify_readers_to_fit_parallel_executor
Modify readers to fit the parallel executor
上级
718e1807
8c1eb869
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
390 addition
and
197 deletion
+390
-197
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+17
-15
paddle/fluid/framework/lod_tensor.h
paddle/fluid/framework/lod_tensor.h
+5
-2
paddle/fluid/framework/lod_tensor_test.cc
paddle/fluid/framework/lod_tensor_test.cc
+9
-9
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+1
-3
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+3
-1
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+3
-10
paddle/fluid/operators/read_op.cc
paddle/fluid/operators/read_op.cc
+1
-7
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+1
-0
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+27
-20
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+3
-13
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+1
-3
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+3
-7
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+11
-13
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+94
-0
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+39
-49
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-1
paddle/fluid/pybind/recordio.cc
paddle/fluid/pybind/recordio.cc
+1
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+137
-28
python/paddle/fluid/tests/unittests/test_multi_file_reader.py
...on/paddle/fluid/tests/unittests/test_multi_file_reader.py
+6
-2
python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
...on/paddle/fluid/tests/unittests/test_multi_pass_reader.py
+7
-3
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+12
-6
python/paddle/fluid/tests/unittests/test_recordio_reader.py
python/paddle/fluid/tests/unittests/test_recordio_reader.py
+9
-4
未找到文件。
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
90084a25
...
@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/memory/memory.h"
...
@@ -22,11 +27,6 @@ limitations under the License. */
...
@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
...
@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
}
}
void
WriteToRecordIO
(
recordio
::
Writer
&
writer
,
void
WriteToRecordIO
(
recordio
::
Writer
*
writer
,
const
std
::
vector
<
LoDTensor
>
&
tensor
,
const
std
::
vector
<
LoDTensor
>
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
stringstream
buffer
;
std
::
stringstream
buffer
;
...
@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
...
@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
for
(
auto
&
each
:
tensor
)
{
for
(
auto
&
each
:
tensor
)
{
SerializeToStream
(
buffer
,
each
,
dev_ctx
);
SerializeToStream
(
buffer
,
each
,
dev_ctx
);
}
}
writer
.
Write
(
buffer
.
str
());
writer
->
Write
(
buffer
.
str
());
}
}
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
recordio
::
Scanner
&
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
recordio
::
Scanner
*
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
istringstream
sin
(
scanner
.
Next
());
uint32_t
sz
;
sin
.
read
(
reinterpret_cast
<
char
*>
(
&
sz
),
sizeof
(
uint32_t
));
std
::
vector
<
LoDTensor
>
result
;
std
::
vector
<
LoDTensor
>
result
;
result
.
resize
(
sz
);
if
(
scanner
->
HasNext
())
{
for
(
uint32_t
i
=
0
;
i
<
sz
;
++
i
)
{
std
::
istringstream
sin
(
scanner
->
Next
());
DeserializeFromStream
(
sin
,
&
result
[
i
],
dev_ctx
);
uint32_t
sz
;
sin
.
read
(
reinterpret_cast
<
char
*>
(
&
sz
),
sizeof
(
uint32_t
));
result
.
resize
(
sz
);
for
(
uint32_t
i
=
0
;
i
<
sz
;
++
i
)
{
DeserializeFromStream
(
sin
,
&
result
[
i
],
dev_ctx
);
}
}
}
return
result
;
return
result
;
}
}
...
...
paddle/fluid/framework/lod_tensor.h
浏览文件 @
90084a25
...
@@ -15,6 +15,9 @@ limitations under the License. */
...
@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#pragma once
#include <memory>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/host_vector.h>
...
@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
...
@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
,
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
const
platform
::
DeviceContext
&
dev_ctx
);
extern
void
WriteToRecordIO
(
recordio
::
Writer
&
writer
,
extern
void
WriteToRecordIO
(
recordio
::
Writer
*
writer
,
const
std
::
vector
<
LoDTensor
>&
tensor
,
const
std
::
vector
<
LoDTensor
>&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
const
platform
::
DeviceContext
&
dev_ctx
);
extern
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
extern
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
recordio
::
Scanner
&
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
);
recordio
::
Scanner
*
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/lod_tensor_test.cc
浏览文件 @
90084a25
...
@@ -12,17 +12,17 @@
...
@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <algorithm>
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
...
@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
{
{
recordio
::
Writer
writer
(
stream
,
recordio
::
Compressor
::
kSnappy
);
recordio
::
Writer
writer
(
stream
,
recordio
::
Compressor
::
kSnappy
);
WriteToRecordIO
(
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
&
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
&
writer
,
{
tensor
,
tensor
},
ctx
);
writer
.
Flush
();
writer
.
Flush
();
}
}
...
@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
...
@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{
{
std
::
unique_ptr
<
std
::
istream
>
stream_ptr
(
stream
);
std
::
unique_ptr
<
std
::
istream
>
stream_ptr
(
stream
);
recordio
::
Scanner
scanner
(
std
::
move
(
stream_ptr
));
recordio
::
Scanner
scanner
(
std
::
move
(
stream_ptr
));
auto
tensors
=
ReadFromRecordIO
(
scanner
,
ctx
);
auto
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
assert_tensor_ok
(
tensors
[
1
]);
tensors
=
ReadFromRecordIO
(
scanner
,
ctx
);
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
assert_tensor_ok
(
tensors
[
1
]);
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
90084a25
...
@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs(
...
@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs(
for
(
auto
&
var
:
vars
)
{
for
(
auto
&
var
:
vars
)
{
auto
*
main_var
=
main_scope
->
FindVar
(
var
);
auto
*
main_var
=
main_scope
->
FindVar
(
var
);
if
(
!
main_var
->
IsType
<
LoDTensor
>
())
{
if
(
main_var
==
nullptr
||
!
main_var
->
IsType
<
LoDTensor
>
())
{
continue
;
continue
;
}
}
auto
&
main_tensor
=
main_var
->
Get
<
LoDTensor
>
();
auto
&
main_tensor
=
main_var
->
Get
<
LoDTensor
>
();
auto
&
dims
=
main_tensor
.
dims
();
auto
&
dims
=
main_tensor
.
dims
();
if
(
paddle
::
platform
::
is_gpu_place
(
main_tensor
.
place
()))
{
if
(
paddle
::
platform
::
is_gpu_place
(
main_tensor
.
place
()))
{
size_t
numel
=
main_tensor
.
numel
();
size_t
numel
=
main_tensor
.
numel
();
ncclDataType_t
data_type
=
platform
::
ToNCCLDataType
(
main_tensor
.
type
());
ncclDataType_t
data_type
=
platform
::
ToNCCLDataType
(
main_tensor
.
type
());
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
90084a25
...
@@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
...
@@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void
FileReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
void
FileReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
ReadNextImpl
(
out
);
ReadNextImpl
(
out
);
PADDLE_ENFORCE_EQ
(
out
->
size
(),
dims_
.
size
());
if
(
out
->
empty
())
{
return
;
}
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
auto
&
actual
=
out
->
at
(
i
).
dims
();
auto
&
actual
=
out
->
at
(
i
).
dims
();
auto
&
expect
=
dims_
[
i
];
auto
&
expect
=
dims_
[
i
];
...
...
paddle/fluid/framework/reader.h
浏览文件 @
90084a25
...
@@ -14,14 +14,13 @@
...
@@ -14,14 +14,13 @@
#pragma once
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -31,8 +30,6 @@ class ReaderBase {
...
@@ -31,8 +30,6 @@ class ReaderBase {
virtual
void
ReInit
()
=
0
;
virtual
void
ReInit
()
=
0
;
virtual
bool
HasNext
()
const
=
0
;
virtual
~
ReaderBase
();
virtual
~
ReaderBase
();
};
};
...
@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
...
@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void
ReInit
()
override
{
reader_
->
ReInit
();
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
bool
HasNext
()
const
override
{
return
reader_
->
HasNext
();
}
protected:
protected:
ReaderBase
*
reader_
;
ReaderBase
*
reader_
;
};
};
...
@@ -80,8 +75,6 @@ class ReaderHolder {
...
@@ -80,8 +75,6 @@ class ReaderHolder {
reader_
->
ReInit
();
reader_
->
ReInit
();
}
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
private:
private:
std
::
unique_ptr
<
ReaderBase
>
reader_
;
std
::
unique_ptr
<
ReaderBase
>
reader_
;
};
};
...
...
paddle/fluid/operators/read_op.cc
浏览文件 @
90084a25
...
@@ -66,13 +66,7 @@ class ReadOp : public framework::OperatorBase {
...
@@ -66,13 +66,7 @@ class ReadOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
framework
::
LoDTensor
>
ins
;
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
PADDLE_ENFORCE
(
!
ins
.
empty
(),
"There is no next data."
);
reader
->
ReInit
();
reader
->
ReadNext
(
&
ins
);
PADDLE_ENFORCE
(
!
ins
.
empty
(),
"Reader can not read the next data even it has been re-initialized."
);
}
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
auto
*
out
=
auto
*
out
=
...
...
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
90084a25
...
@@ -22,5 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
...
@@ -22,5 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library
(
create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc
)
reader_library
(
create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc
)
reader_library
(
create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc
)
reader_library
(
create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc
)
reader_library
(
create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc
)
reader_library
(
create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc
)
reader_library
(
create_threaded_reader_op SRCS create_threaded_reader_op.cc
)
# Export local libraries to parent
# Export local libraries to parent
set
(
READER_LIBRARY
${
LOCAL_READER_LIBS
}
PARENT_SCOPE
)
set
(
READER_LIBRARY
${
LOCAL_READER_LIBS
}
PARENT_SCOPE
)
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
90084a25
...
@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
...
@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher
();
StartPrefetcher
();
}
}
bool
HasNext
()
const
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
void
ReInit
()
override
;
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
private:
private:
bool
HasNext
()
const
;
void
StartPrefetcher
()
{
void
StartPrefetcher
()
{
channel_
=
framework
::
MakeChannel
<
Item
>
(
kChannelSize
);
channel_
=
framework
::
MakeChannel
<
Item
>
(
kChannelSize
);
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
...
@@ -109,7 +110,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...
@@ -109,7 +110,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
platform
::
Place
place
;
platform
::
Place
place
;
if
(
place_str
==
"CPU"
)
{
if
(
place_str
==
"AUTO"
)
{
place
=
dev_place
;
}
else
if
(
place_str
==
"CPU"
)
{
place
=
platform
::
CPUPlace
();
place
=
platform
::
CPUPlace
();
}
else
{
}
else
{
std
::
istringstream
sin
(
place_str
);
std
::
istringstream
sin
(
place_str
);
...
@@ -140,28 +143,22 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -140,28 +143,22 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
enum_range
.
insert
(
string
::
Sprintf
(
"CUDA:%d"
,
i
));
enum_range
.
insert
(
string
::
Sprintf
(
"CUDA:%d"
,
i
));
}
}
enum_range
.
insert
(
"CPU"
);
enum_range
.
insert
(
"CPU"
);
AddAttr
<
std
::
string
>
(
"place"
,
"The double buffer place, default is CPU"
)
enum_range
.
insert
(
"AUTO"
);
.
SetDefault
(
"CPU"
)
AddAttr
<
std
::
string
>
(
"place"
,
"The double buffer place"
)
.
SetDefault
(
"AUTO"
)
.
InEnum
({
enum_range
});
.
InEnum
({
enum_range
});
}
}
};
};
bool
DoubleBufferReader
::
HasNext
()
const
{
while
(
!
channel_
->
IsClosed
()
&&
!
channel_
->
CanReceive
())
{
}
return
channel_
->
CanReceive
();
}
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
out
->
clear
();
PADDLE_THROW
(
"There is no next data!"
);
if
(
HasNext
())
{
}
Item
batch
;
channel_
->
Receive
(
&
batch
);
Item
batch
;
*
out
=
batch
.
payloads_
;
channel_
->
Receive
(
&
batch
);
if
(
batch
.
ctx_
)
{
*
out
=
batch
.
payloads_
;
batch
.
ctx_
->
Wait
();
if
(
batch
.
ctx_
)
{
}
batch
.
ctx_
->
Wait
();
}
}
}
}
...
@@ -171,16 +168,26 @@ void DoubleBufferReader::ReInit() {
...
@@ -171,16 +168,26 @@ void DoubleBufferReader::ReInit() {
StartPrefetcher
();
StartPrefetcher
();
}
}
bool
DoubleBufferReader
::
HasNext
()
const
{
while
(
!
channel_
->
IsClosed
()
&&
!
channel_
->
CanReceive
())
{
}
return
channel_
->
CanReceive
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
cpu_tensor_cache
(
kCacheSize
);
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
cpu_tensor_cache
(
kCacheSize
);
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
gpu_tensor_cache
(
kCacheSize
);
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
gpu_tensor_cache
(
kCacheSize
);
size_t
cached_tensor_id
=
0
;
size_t
cached_tensor_id
=
0
;
while
(
reader_
->
HasNext
()
)
{
while
(
true
)
{
Item
batch
;
Item
batch
;
auto
&
cpu_batch
=
cpu_tensor_cache
[
cached_tensor_id
];
auto
&
cpu_batch
=
cpu_tensor_cache
[
cached_tensor_id
];
reader_
->
ReadNext
(
&
cpu_batch
);
reader_
->
ReadNext
(
&
cpu_batch
);
if
(
cpu_batch
.
empty
())
{
// The underlying reader have no next data.
break
;
}
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
auto
&
gpu_batch
=
gpu_tensor_cache
[
cached_tensor_id
];
auto
&
gpu_batch
=
gpu_tensor_cache
[
cached_tensor_id
];
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
90084a25
...
@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
...
@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
reader_
->
ReadNext
(
out
);
reader_
->
ReadNext
(
out
);
}
if
(
out
->
empty
())
{
bool
HasNext
()
const
override
{
if
(
reader_
->
HasNext
())
{
return
true
;
}
else
{
++
pass_count_
;
++
pass_count_
;
if
(
pass_count_
>=
pass_num_
)
{
if
(
pass_count_
<
pass_num_
)
{
return
false
;
}
else
{
reader_
->
ReInit
();
reader_
->
ReInit
();
re
turn
true
;
re
ader_
->
ReadNext
(
out
)
;
}
}
}
}
}
}
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
90084a25
...
@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
...
@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
void
ReInit
()
override
{
return
;
}
void
ReInit
()
override
{
return
;
}
bool
HasNext
()
const
override
{
return
true
;
}
private:
private:
float
min_
;
float
min_
;
float
max_
;
float
max_
;
...
@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
...
@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
90084a25
...
@@ -12,8 +12,6 @@
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/scanner.h"
...
@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
...
@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
}
}
bool
HasNext
()
const
override
{
return
scanner_
.
HasNext
();
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
ThreadSafe
)
{
if
(
ThreadSafe
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
*
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
*
mutex_
);
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
*
out
=
framework
::
ReadFromRecordIO
(
&
scanner_
,
dev_ctx_
);
}
else
{
}
else
{
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
*
out
=
framework
::
ReadFromRecordIO
(
&
scanner_
,
dev_ctx_
);
}
}
}
}
...
@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
...
@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
"shape concat's length."
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
90084a25
...
@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
...
@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
std
::
random_device
device
;
std
::
random_device
device
;
seed_
=
device
();
seed_
=
device
();
}
}
Re
adIntoBuffers
();
Re
loadBuffer
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
!
HasNext
())
{
out
->
clear
();
PADDLE_THROW
(
"There is no next data!"
);
}
if
(
iteration_pos_
>=
buffer_
.
size
())
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
ReadIntoBuffers
();
ReloadBuffer
();
if
(
buffer_
.
empty
())
{
return
;
}
}
}
*
out
=
buffer_
[
iteration_pos_
++
];
*
out
=
buffer_
[
iteration_pos_
++
];
}
}
bool
HasNext
()
const
override
{
return
iteration_pos_
<
buffer_
.
size
()
||
reader_
->
HasNext
();
}
private:
private:
void
Re
adIntoBuffers
()
{
void
Re
loadBuffer
()
{
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
buffer_
.
reserve
(
buffer_size_
);
iteration_pos_
=
0
;
iteration_pos_
=
0
;
for
(
size_t
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
!
reader_
->
HasNext
())
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader_
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
break
;
break
;
}
}
buffer_
.
emplace_back
();
buffer_
.
emplace_back
(
ins
);
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
}
std
::
mt19937
g
(
seed_
);
std
::
mt19937
g
(
seed_
);
std
::
shuffle
(
buffer_
.
begin
(),
buffer_
.
end
(),
g
);
std
::
shuffle
(
buffer_
.
begin
(),
buffer_
.
end
(),
g
);
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
0 → 100644
浏览文件 @
90084a25
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
public:
ThreadedReader
(
ReaderBase
*
reader
,
bool
safe_mode
)
:
DecoratedReader
(
reader
),
safe_mode_
(
safe_mode
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
override
{
if
(
safe_mode_
)
{
PADDLE_THROW
(
"ThreadedReader::ReInit() is disabled when 'safe_mode' is true."
);
}
VLOG
(
5
)
<<
"ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment."
;
reader_
->
ReInit
();
}
private:
bool
safe_mode_
;
std
::
mutex
mutex_
;
};
class
CreateThreadedReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)))
.
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
bool
safe_mode
=
Attr
<
bool
>
(
"safe_mode"
);
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
(),
safe_mode
));
}
};
class
CreateThreadedReaderOpMaker
:
public
DecoratedReaderMakerBase
{
public:
CreateThreadedReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
DecoratedReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
bool
>
(
"safe_mode"
,
"When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
"unexpected bugs in multi-thread environment."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
reader
=
paddle
::
operators
::
reader
;
REGISTER_DECORATED_READER_OPERATOR
(
create_threaded_reader
,
reader
::
CreateThreadedReaderOp
,
reader
::
CreateThreadedReaderOpMaker
);
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
90084a25
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <thread> // NOLINT
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
...
@@ -19,38 +21,23 @@ namespace paddle {
...
@@ -19,38 +21,23 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
reader
{
namespace
reader
{
class
Multi
p
leReader
:
public
framework
::
ReaderBase
{
class
Multi
Fi
leReader
:
public
framework
::
ReaderBase
{
public:
public:
class
ThreadBufferMap
{
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
public:
const
std
::
vector
<
framework
::
DDim
>&
dims
,
size_t
thread_num
,
std
::
vector
<
framework
::
LoDTensor
>&
operator
[](
size_t
buffer_size
)
const
std
::
thread
::
id
&
thread_id
)
{
:
file_names_
(
file_names
),
dims_
(
dims
),
buffer_size_
(
buffer_size
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
buffer_
[
thread_id
];
}
void
Clear
()
{
buffer_
.
clear
();
}
private:
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
thread
::
id
,
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
};
MultipleReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
size_t
thread_num
)
:
file_names_
(
file_names
),
dims_
(
dims
)
{
prefetchers_
.
resize
(
thread_num
);
prefetchers_
.
resize
(
thread_num
);
StartNewScheduler
();
StartNewScheduler
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
bool
HasNext
()
const
override
;
void
ReInit
()
override
;
void
ReInit
()
override
;
~
Multi
p
leReader
()
{
EndScheduler
();
}
~
Multi
Fi
leReader
()
{
EndScheduler
();
}
private:
private:
bool
HasNext
();
void
StartNewScheduler
();
void
StartNewScheduler
();
void
EndScheduler
();
void
EndScheduler
();
void
ScheduleThreadFunc
();
void
ScheduleThreadFunc
();
...
@@ -60,39 +47,36 @@ class MultipleReader : public framework::ReaderBase {
...
@@ -60,39 +47,36 @@ class MultipleReader : public framework::ReaderBase {
std
::
vector
<
framework
::
DDim
>
dims_
;
std
::
vector
<
framework
::
DDim
>
dims_
;
std
::
thread
scheduler_
;
std
::
thread
scheduler_
;
std
::
vector
<
std
::
thread
>
prefetchers_
;
std
::
vector
<
std
::
thread
>
prefetchers_
;
size_t
buffer_size_
;
framework
::
Channel
<
size_t
>*
waiting_file_idx_
;
framework
::
Channel
<
size_t
>*
waiting_file_idx_
;
framework
::
Channel
<
size_t
>*
available_thread_idx_
;
framework
::
Channel
<
size_t
>*
available_thread_idx_
;
framework
::
Channel
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
framework
::
Channel
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
mutable
ThreadBufferMap
thread_buffer_map_
;
};
};
void
MultipleReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
out
->
clear
();
PADDLE_THROW
(
"There is no next data!"
);
if
(
HasNext
())
{
buffer_
->
Receive
(
out
);
}
}
auto
&
thread_local_buffer
=
thread_buffer_map_
[
std
::
this_thread
::
get_id
()];
*
out
=
thread_local_buffer
;
thread_local_buffer
.
clear
();
}
bool
MultipleReader
::
HasNext
()
const
{
auto
&
thread_local_buffer
=
thread_buffer_map_
[
std
::
this_thread
::
get_id
()];
return
thread_local_buffer
.
empty
()
?
buffer_
->
Receive
(
&
thread_local_buffer
)
:
true
;
}
}
void
Multi
p
leReader
::
ReInit
()
{
void
Multi
Fi
leReader
::
ReInit
()
{
EndScheduler
();
EndScheduler
();
thread_buffer_map_
.
Clear
();
StartNewScheduler
();
StartNewScheduler
();
}
}
void
MultipleReader
::
StartNewScheduler
()
{
bool
MultiFileReader
::
HasNext
()
{
while
(
!
buffer_
->
IsClosed
()
&&
!
buffer_
->
CanReceive
())
{
}
return
buffer_
->
CanReceive
();
}
void
MultiFileReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
size_t
thread_num
=
prefetchers_
.
size
();
waiting_file_idx_
=
framework
::
MakeChannel
<
size_t
>
(
file_names_
.
size
());
waiting_file_idx_
=
framework
::
MakeChannel
<
size_t
>
(
file_names_
.
size
());
available_thread_idx_
=
framework
::
MakeChannel
<
size_t
>
(
thread_num
);
available_thread_idx_
=
framework
::
MakeChannel
<
size_t
>
(
thread_num
);
buffer_
=
buffer_
=
framework
::
MakeChannel
<
std
::
vector
<
framework
::
LoDTensor
>>
(
thread_num
);
framework
::
MakeChannel
<
std
::
vector
<
framework
::
LoDTensor
>>
(
buffer_size_
);
for
(
size_t
i
=
0
;
i
<
file_names_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
file_names_
.
size
();
++
i
)
{
waiting_file_idx_
->
Send
(
&
i
);
waiting_file_idx_
->
Send
(
&
i
);
...
@@ -105,7 +89,7 @@ void MultipleReader::StartNewScheduler() {
...
@@ -105,7 +89,7 @@ void MultipleReader::StartNewScheduler() {
scheduler_
=
std
::
thread
([
this
]
{
ScheduleThreadFunc
();
});
scheduler_
=
std
::
thread
([
this
]
{
ScheduleThreadFunc
();
});
}
}
void
Multi
p
leReader
::
EndScheduler
()
{
void
Multi
Fi
leReader
::
EndScheduler
()
{
available_thread_idx_
->
Close
();
available_thread_idx_
->
Close
();
buffer_
->
Close
();
buffer_
->
Close
();
waiting_file_idx_
->
Close
();
waiting_file_idx_
->
Close
();
...
@@ -117,8 +101,8 @@ void MultipleReader::EndScheduler() {
...
@@ -117,8 +101,8 @@ void MultipleReader::EndScheduler() {
delete
waiting_file_idx_
;
delete
waiting_file_idx_
;
}
}
void
Multi
p
leReader
::
ScheduleThreadFunc
()
{
void
Multi
Fi
leReader
::
ScheduleThreadFunc
()
{
VLOG
(
5
)
<<
"Multi
p
leReader schedule thread starts."
;
VLOG
(
5
)
<<
"Multi
Fi
leReader schedule thread starts."
;
size_t
completed_thread_num
=
0
;
size_t
completed_thread_num
=
0
;
size_t
thread_idx
;
size_t
thread_idx
;
while
(
available_thread_idx_
->
Receive
(
&
thread_idx
))
{
while
(
available_thread_idx_
->
Receive
(
&
thread_idx
))
{
...
@@ -150,17 +134,20 @@ void MultipleReader::ScheduleThreadFunc() {
...
@@ -150,17 +134,20 @@ void MultipleReader::ScheduleThreadFunc() {
p
.
join
();
p
.
join
();
}
}
}
}
VLOG
(
5
)
<<
"Multi
p
leReader schedule thread terminates."
;
VLOG
(
5
)
<<
"Multi
Fi
leReader schedule thread terminates."
;
}
}
void
Multi
p
leReader
::
PrefetchThreadFunc
(
std
::
string
file_name
,
void
Multi
Fi
leReader
::
PrefetchThreadFunc
(
std
::
string
file_name
,
size_t
thread_idx
)
{
size_t
thread_idx
)
{
VLOG
(
5
)
<<
"The prefetch thread of file '"
<<
file_name
<<
"' starts."
;
VLOG
(
5
)
<<
"The prefetch thread of file '"
<<
file_name
<<
"' starts."
;
std
::
unique_ptr
<
framework
::
ReaderBase
>
reader
=
std
::
unique_ptr
<
framework
::
ReaderBase
>
reader
=
CreateReaderByFileName
(
file_name
,
dims_
);
CreateReaderByFileName
(
file_name
,
dims_
);
while
(
reader
->
HasNext
()
)
{
while
(
true
)
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
break
;
}
try
{
try
{
buffer_
->
Send
(
&
ins
);
buffer_
->
Send
(
&
ins
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
...
@@ -197,11 +184,13 @@ class OpenFilesOp : public framework::OperatorBase {
...
@@ -197,11 +184,13 @@ class OpenFilesOp : public framework::OperatorBase {
const
auto
&
file_names
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"file_names"
);
const
auto
&
file_names
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"file_names"
);
PADDLE_ENFORCE
(
!
file_names
.
empty
(),
"No file to be read!"
);
PADDLE_ENFORCE
(
!
file_names
.
empty
(),
"No file to be read!"
);
const
size_t
thread_num
=
Attr
<
int
>
(
"thread_num"
);
const
size_t
thread_num
=
Attr
<
int
>
(
"thread_num"
);
const
size_t
buffer_size
=
Attr
<
int
>
(
"buffer_size"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
MultipleReader
(
out
->
Reset
(
new
MultiFileReader
(
file_names
,
file_names
,
RestoreShapes
(
shape_concat
,
ranks
),
thread_num
));
RestoreShapes
(
shape_concat
,
ranks
),
thread_num
,
buffer_size
));
}
}
};
};
...
@@ -212,11 +201,12 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
...
@@ -212,11 +201,12 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"file_names"
,
"Files to be read."
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"file_names"
,
"Files to be read."
);
AddAttr
<
int
>
(
"thread_num"
,
"The maximal concurrent prefetch thread number."
)
AddAttr
<
int
>
(
"thread_num"
,
"The maximal concurrent prefetch thread number."
)
.
GreaterThan
(
0
);
.
GreaterThan
(
0
);
AddAttr
<
int
>
(
"buffer_size"
,
"The size of prefetch buffer."
).
GreaterThan
(
0
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
OpenFiles Operator
OpenFiles Operator
An OpenFilesOp creates a Multi
p
leReader, which is able to
An OpenFilesOp creates a Multi
Fi
leReader, which is able to
read data multi-threaded from multiple files.
read data multi-threaded from multiple files.
)DOC"
);
)DOC"
);
}
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
90084a25
...
@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"has_next"
,
&
framework
::
ReaderHolder
::
HasNext
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
ReInit
);
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
ReInit
);
py
::
class_
<
Scope
>
(
m
,
"Scope"
,
""
)
py
::
class_
<
Scope
>
(
m
,
"Scope"
,
""
)
...
...
paddle/fluid/pybind/recordio.cc
浏览文件 @
90084a25
...
@@ -39,7 +39,7 @@ class RecordIOWriter {
...
@@ -39,7 +39,7 @@ class RecordIOWriter {
void
CompleteAppendTensor
()
{
void
CompleteAppendTensor
()
{
auto
&
ctx
=
auto
&
ctx
=
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
framework
::
WriteToRecordIO
(
writer_
,
tensors_
,
ctx
);
framework
::
WriteToRecordIO
(
&
writer_
,
tensors_
,
ctx
);
tensors_
.
clear
();
tensors_
.
clear
();
}
}
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
90084a25
...
@@ -21,8 +21,7 @@ from ..executor import global_scope
...
@@ -21,8 +21,7 @@ from ..executor import global_scope
__all__
=
[
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'open_files'
,
'read_file'
,
'create_shuffle_reader'
,
'open_files'
,
'read_file'
,
'shuffle'
,
'double_buffer'
'create_double_buffer_reader'
,
'create_multi_pass_reader'
]
]
...
@@ -237,13 +236,9 @@ def monkey_patch_reader_methods(reader):
...
@@ -237,13 +236,9 @@ def monkey_patch_reader_methods(reader):
var
=
scope
.
find_var
(
reader
.
name
)
var
=
scope
.
find_var
(
reader
.
name
)
return
var
.
get_reader
()
return
var
.
get_reader
()
def
eof
():
return
not
__get_reader__
().
has_next
()
def
reset
():
def
reset
():
return
__get_reader__
().
reset
()
return
__get_reader__
().
reset
()
reader
.
eof
=
eof
reader
.
reset
=
reset
reader
.
reset
=
reset
reader
.
stop_gradient
=
True
reader
.
stop_gradient
=
True
reader
.
persistable
=
True
reader
.
persistable
=
True
...
@@ -283,7 +278,42 @@ def _copy_reader_create_op_(block, op):
...
@@ -283,7 +278,42 @@ def _copy_reader_create_op_(block, op):
return
new_op
return
new_op
def
open_recordio_file
(
filename
,
shapes
,
lod_levels
,
dtypes
):
def
open_recordio_file
(
filename
,
shapes
,
lod_levels
,
dtypes
,
pass_num
=
1
,
for_parallel
=
False
):
"""
Open a RecordIO file
This layer takes a RecordIO file to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given RecordIO file.
Args:
filename(str): The RecordIO file's name.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get RecordIO file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_recordio_file(
filename='./data.recordio',
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
"""
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
shape_concat
=
[]
shape_concat
=
[]
ranks
=
[]
ranks
=
[]
...
@@ -310,10 +340,63 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
...
@@ -310,10 +340,63 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var
.
persistable
=
True
startup_var
.
persistable
=
True
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
startup_var
)
if
pass_num
>
1
:
main_prog_var
=
multi_pass
(
reader
=
main_prog_var
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_var
=
parallel
(
reader
=
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
def
open_files
(
filenames
,
thread_num
,
shapes
,
lod_levels
,
dtypes
):
def
open_files
(
filenames
,
shapes
,
lod_levels
,
dtypes
,
thread_num
,
buffer_size
=
None
,
pass_num
=
1
,
for_parallel
=
False
):
"""
Open files
This layer takes a list of files to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from given files. All files must
have name suffixs to indicate their formats, e.g., '*.recordio'.
Args:
filenames(list): The list of file names.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_files(filenames=['./data1.recordio',
'./data2.recordio'],
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
if
buffer_size
is
None
:
buffer_size
=
thread_num
if
isinstance
(
filenames
,
basestring
):
filenames
=
[
filenames
]
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
shape_concat
=
[]
shape_concat
=
[]
ranks
=
[]
ranks
=
[]
...
@@ -322,29 +405,36 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
...
@@ -322,29 +405,36 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
shape_concat
.
extend
(
shape
)
shape_concat
.
extend
(
shape
)
ranks
.
append
(
len
(
shape
))
ranks
.
append
(
len
(
shape
))
var_name
=
unique_name
(
'multiple_reader'
)
multi_file_reader_name
=
unique_name
(
'multi_file_reader'
)
startup_blk
=
default_startup_program
().
current_block
()
startup_blk
=
default_startup_program
().
current_block
()
startup_
var
=
startup_blk
.
create_var
(
name
=
va
r_name
)
startup_
reader
=
startup_blk
.
create_var
(
name
=
multi_file_reade
r_name
)
startup_blk
.
append_op
(
startup_blk
.
append_op
(
type
=
'open_files'
,
type
=
'open_files'
,
outputs
=
{
'Out'
:
[
startup_
va
r
]},
outputs
=
{
'Out'
:
[
startup_
reade
r
]},
attrs
=
{
attrs
=
{
'shape_concat'
:
shape_concat
,
'shape_concat'
:
shape_concat
,
'lod_levels'
:
lod_levels
,
'lod_levels'
:
lod_levels
,
'ranks'
:
ranks
,
'ranks'
:
ranks
,
'file_names'
:
filenames
,
'file_names'
:
filenames
,
'thread_num'
:
thread_num
'thread_num'
:
thread_num
,
'buffer_size'
:
buffer_size
})
})
startup_var
.
desc
.
set_dtypes
(
dtypes
)
startup_reader
.
desc
.
set_dtypes
(
dtypes
)
startup_var
.
persistable
=
True
startup_reader
.
persistable
=
True
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
main_prog_reader
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
startup_reader
)
return
monkey_patch_reader_methods
(
main_prog_var
)
if
pass_num
>
1
:
main_prog_reader
=
multi_pass
(
reader
=
main_prog_reader
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_reader
=
parallel
(
reader
=
main_prog_reader
)
return
monkey_patch_reader_methods
(
main_prog_reader
)
def
__create_decorated_reader__
(
op_type
,
reader
,
attrs
):
def
__create_
shared_
decorated_reader__
(
op_type
,
reader
,
attrs
):
var_name
=
unique_name
(
op_type
)
var_name
=
unique_name
(
op_type
)
startup_blk
=
default_startup_program
().
current_block
()
startup_blk
=
default_startup_program
().
current_block
()
startup_var
=
startup_blk
.
create_var
(
name
=
var_name
)
startup_var
=
startup_blk
.
create_var
(
name
=
var_name
)
...
@@ -360,22 +450,41 @@ def __create_decorated_reader__(op_type, reader, attrs):
...
@@ -360,22 +450,41 @@ def __create_decorated_reader__(op_type, reader, attrs):
return
monkey_patch_reader_methods
(
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
def
create_shuffle_reader
(
reader
,
buffer_size
):
def
__create_unshared_decorated_reader__
(
op_type
,
reader
,
attrs
):
return
__create_decorated_reader__
(
'create_shuffle_reader'
,
reader
,
new_reader_name
=
unique_name
(
op_type
)
{
'buffer_size'
:
int
(
buffer_size
)})
main_blk
=
default_main_program
().
current_block
()
new_reader
=
main_blk
.
create_var
(
name
=
new_reader_name
)
main_blk
.
append_op
(
type
=
op_type
,
inputs
=
{
'UnderlyingReader'
:
reader
},
outputs
=
{
'Out'
:
[
new_reader
]},
attrs
=
attrs
)
new_reader
.
persistable
=
True
new_reader
.
stop_gradient
=
True
return
monkey_patch_reader_methods
(
new_reader
)
def
shuffle
(
reader
,
buffer_size
):
return
__create_unshared_decorated_reader__
(
'create_shuffle_reader'
,
reader
,
{
'buffer_size'
:
int
(
buffer_size
)})
def
create_double_buffer_read
er
(
reader
,
place
=
None
):
def
double_buff
er
(
reader
,
place
=
None
):
attrs
=
dict
()
attrs
=
dict
()
if
place
is
not
None
:
if
place
is
not
None
:
attrs
[
'place'
]
=
str
(
place
).
upper
()
attrs
[
'place'
]
=
str
(
place
).
upper
()
return
__create_decorated_reader__
(
'create_double_buffer_reader'
,
reader
,
return
__create_unshared_decorated_reader__
(
'create_double_buffer_reader'
,
attrs
)
reader
,
attrs
)
def
multi_pass
(
reader
,
pass_num
):
return
__create_shared_decorated_reader__
(
'create_multi_pass_reader'
,
reader
,
{
'pass_num'
:
int
(
pass_num
)})
def
create_multi_pass_reader
(
reader
,
pass_num
):
def
parallel
(
reader
):
return
__create_
decorated_reader__
(
'create_multi_pass
_reader'
,
reader
,
return
__create_
shared_decorated_reader__
(
'create_threaded
_reader'
,
reader
,
{
'pass_num'
:
int
(
pass_num
)
})
{
})
def
read_file
(
file_obj
):
def
read_file
(
file_obj
):
...
...
python/paddle/fluid/tests/unittests/test_multi
p
le_reader.py
→
python/paddle/fluid/tests/unittests/test_multi
_fi
le_reader.py
浏览文件 @
90084a25
...
@@ -61,8 +61,12 @@ class TestMultipleReader(unittest.TestCase):
...
@@ -61,8 +61,12 @@ class TestMultipleReader(unittest.TestCase):
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
batch_count
=
0
batch_count
=
0
while
not
data_files
.
eof
():
while
True
:
img_val
,
=
exe
.
run
(
fetch_list
=
[
img
])
try
:
img_val
,
=
exe
.
run
(
fetch_list
=
[
img
])
except
fluid
.
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"There is no next data."
,
ex
.
message
)
break
batch_count
+=
1
batch_count
+=
1
self
.
assertLessEqual
(
img_val
.
shape
[
0
],
self
.
batch_size
)
self
.
assertLessEqual
(
img_val
.
shape
[
0
],
self
.
batch_size
)
data_files
.
reset
()
data_files
.
reset
()
...
...
python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
浏览文件 @
90084a25
...
@@ -44,7 +44,7 @@ class TestMultipleReader(unittest.TestCase):
...
@@ -44,7 +44,7 @@ class TestMultipleReader(unittest.TestCase):
shapes
=
[(
-
1
,
784
),
(
-
1
,
1
)],
shapes
=
[(
-
1
,
784
),
(
-
1
,
1
)],
lod_levels
=
[
0
,
0
],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
dtypes
=
[
'float32'
,
'int64'
])
data_file
=
fluid
.
layers
.
create_multi_pass_reader
(
data_file
=
fluid
.
layers
.
io
.
multi_pass
(
reader
=
data_file
,
pass_num
=
self
.
pass_num
)
reader
=
data_file
,
pass_num
=
self
.
pass_num
)
img
,
label
=
fluid
.
layers
.
read_file
(
data_file
)
img
,
label
=
fluid
.
layers
.
read_file
(
data_file
)
...
@@ -57,8 +57,12 @@ class TestMultipleReader(unittest.TestCase):
...
@@ -57,8 +57,12 @@ class TestMultipleReader(unittest.TestCase):
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
batch_count
=
0
batch_count
=
0
while
not
data_file
.
eof
():
while
True
:
img_val
,
=
exe
.
run
(
fetch_list
=
[
img
])
try
:
img_val
,
=
exe
.
run
(
fetch_list
=
[
img
])
except
fluid
.
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"There is no next data."
,
ex
.
message
)
break
batch_count
+=
1
batch_count
+=
1
self
.
assertLessEqual
(
img_val
.
shape
[
0
],
self
.
batch_size
)
self
.
assertLessEqual
(
img_val
.
shape
[
0
],
self
.
batch_size
)
data_file
.
reset
()
data_file
.
reset
()
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
90084a25
...
@@ -26,11 +26,14 @@ def simple_fc_net(use_feed):
...
@@ -26,11 +26,14 @@ def simple_fc_net(use_feed):
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
else
:
else
:
reader
=
fluid
.
layers
.
open_
recordio_file
(
reader
=
fluid
.
layers
.
open_
files
(
filename
=
'./mnist.recordio'
,
filename
s
=
[
'./mnist.recordio'
]
,
shapes
=
[[
-
1
,
784
],
[
-
1
,
1
]],
shapes
=
[[
-
1
,
784
],
[
-
1
,
1
]],
lod_levels
=
[
0
,
0
],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
dtypes
=
[
'float32'
,
'int64'
],
thread_num
=
1
,
for_parallel
=
True
)
reader
=
fluid
.
layers
.
io
.
double_buffer
(
reader
)
img
,
label
=
fluid
.
layers
.
read_file
(
reader
)
img
,
label
=
fluid
.
layers
.
read_file
(
reader
)
hidden
=
img
hidden
=
img
for
_
in
xrange
(
4
):
for
_
in
xrange
(
4
):
...
@@ -51,11 +54,14 @@ def fc_with_batchnorm(use_feed):
...
@@ -51,11 +54,14 @@ def fc_with_batchnorm(use_feed):
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
else
:
else
:
reader
=
fluid
.
layers
.
open_
recordio_file
(
reader
=
fluid
.
layers
.
open_
files
(
filename
=
'./mnist.recordio'
,
filename
s
=
[
'mnist.recordio'
]
,
shapes
=
[[
-
1
,
784
],
[
-
1
,
1
]],
shapes
=
[[
-
1
,
784
],
[
-
1
,
1
]],
lod_levels
=
[
0
,
0
],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
dtypes
=
[
'float32'
,
'int64'
],
thread_num
=
1
,
for_parallel
=
True
)
reader
=
fluid
.
layers
.
io
.
double_buffer
(
reader
)
img
,
label
=
fluid
.
layers
.
read_file
(
reader
)
img
,
label
=
fluid
.
layers
.
read_file
(
reader
)
hidden
=
img
hidden
=
img
...
...
python/paddle/fluid/tests/unittests/test_recordio_reader.py
浏览文件 @
90084a25
...
@@ -65,8 +65,13 @@ class TestRecordIO(unittest.TestCase):
...
@@ -65,8 +65,13 @@ class TestRecordIO(unittest.TestCase):
# train a pass
# train a pass
batch_id
=
0
batch_id
=
0
while
not
data_file
.
eof
():
while
True
:
tmp
,
=
exe
.
run
(
fetch_list
=
[
avg_loss
])
try
:
tmp
,
=
exe
.
run
(
fetch_list
=
[
avg_loss
])
except
fluid
.
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"There is no next data."
,
ex
.
message
)
break
avg_loss_np
.
append
(
tmp
)
avg_loss_np
.
append
(
tmp
)
batch_id
+=
1
batch_id
+=
1
data_file
.
reset
()
data_file
.
reset
()
...
@@ -74,8 +79,8 @@ class TestRecordIO(unittest.TestCase):
...
@@ -74,8 +79,8 @@ class TestRecordIO(unittest.TestCase):
self
.
assertLess
(
avg_loss_np
[
-
1
],
avg_loss_np
[
0
])
self
.
assertLess
(
avg_loss_np
[
-
1
],
avg_loss_np
[
0
])
def
test_shuffle_reader
(
self
):
def
test_shuffle_reader
(
self
):
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
create_shuffle_reader
(
reader
,
buffer_size
=
200
))
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
io
.
shuffle
(
reader
,
buffer_size
=
200
))
def
test_double_buffer_reader
(
self
):
def
test_double_buffer_reader
(
self
):
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
create_double_buffer_read
er
(
reader
,
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
io
.
double_buff
er
(
reader
,
place
=
'cuda:0'
if
fluid
.
core
.
is_compiled_with_cuda
()
else
'cpu'
))
place
=
'cuda:0'
if
fluid
.
core
.
is_compiled_with_cuda
()
else
'cpu'
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录