Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
72be7a61
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
72be7a61
编写于
3月 07, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete RecordIO reader op
上级
bcb80756
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
395 addition
and
27 deletion
+395
-27
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+7
-1
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+87
-0
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+1
-1
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-0
paddle/fluid/pybind/recordio.cc
paddle/fluid/pybind/recordio.cc
+69
-0
paddle/fluid/pybind/recordio.h
paddle/fluid/pybind/recordio.h
+26
-0
paddle/fluid/recordio/chunk.cc
paddle/fluid/recordio/chunk.cc
+20
-20
paddle/fluid/recordio/header.cc
paddle/fluid/recordio/header.cc
+2
-2
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+2
-0
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+58
-2
python/paddle/fluid/recordio_writer.py
python/paddle/fluid/recordio_writer.py
+62
-0
python/paddle/fluid/tests/unittests/test_recordio_reader.py
python/paddle/fluid/tests/unittests/test_recordio_reader.py
+56
-0
未找到文件。
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
72be7a61
...
...
@@ -2,4 +2,10 @@ cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_regist
op_library
(
create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry
)
op_library
(
create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry
)
op_library
(
create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry
)
set
(
READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op PARENT_SCOPE
)
op_library
(
create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc DEPS reader_op_registry
)
set
(
READER_LIBRARY
create_recordio_file_reader_op
create_random_data_generator_op
create_shuffle_reader_op
create_batch_reader_op
PARENT_SCOPE
)
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
0 → 100644
浏览文件 @
72be7a61
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
RecordIOFileReader
:
public
framework
::
FileReader
{
public:
RecordIOFileReader
(
const
std
::
string
&
filename
,
const
std
::
vector
<
framework
::
DDim
>&
shapes
)
:
FileReader
(
shapes
),
scanner_
(
filename
),
dev_ctx_
(
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()))
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
}
bool
HasNext
()
const
override
{
return
scanner_
.
HasNext
();
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
private:
recordio
::
Scanner
scanner_
;
const
platform
::
DeviceContext
&
dev_ctx_
;
};
class
CreateRecordIOReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RecordIOFileReader
(
filename
,
shapes
));
}
};
class
CreateRecordIOReaderOpMaker
:
public
FileReaderMakerBase
{
public:
CreateRecordIOReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
FileReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
std
::
string
>
(
"filename"
,
"The filename of record io reader"
);
AddComment
(
R"DOC(
CreateRecordIOReader Operator
Create a reader from a record io file
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
reader
=
paddle
::
operators
::
reader
;
REGISTER_FILE_READER_OPERATOR
(
create_recordio_file_reader
,
reader
::
CreateRecordIOReaderOp
,
reader
::
CreateRecordIOReaderOpMaker
);
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
72be7a61
...
...
@@ -35,7 +35,7 @@ FileReaderMakerBase::FileReaderMakerBase(
framework
::
OpProtoAndCheckerMaker
::
OpProto
*
op_proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"(ReaderHolder) The created random reader."
);
AddOutput
(
"Out"
,
"(ReaderHolder) The created random reader."
)
.
AsDuplicable
()
;
AddAttr
<
std
::
vector
<
int
>>
(
"shape_concat"
,
"The concat of all data's shapes."
);
AddAttr
<
std
::
vector
<
int
>>
(
"ranks"
,
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
72be7a61
if
(
WITH_PYTHON
)
cc_library
(
paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
recordio.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
${
GLOB_OP_LIB
}
)
if
(
NOT APPLE AND NOT ANDROID
)
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
72be7a61
...
...
@@ -35,7 +35,9 @@ limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h"
#ifdef PADDLE_WITH_CUDA
...
...
@@ -474,6 +476,8 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"enable_profiler"
,
platform
::
EnableProfiler
);
m
.
def
(
"disable_profiler"
,
platform
::
DisableProfiler
);
m
.
def
(
"reset_profiler"
,
platform
::
ResetProfiler
);
BindRecordIOWriter
(
m
);
return
m
.
ptr
();
}
}
// namespace pybind
...
...
paddle/fluid/pybind/recordio.cc
0 → 100644
浏览文件 @
72be7a61
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/recordio.h"
#include <fstream>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/writer.h"
namespace
paddle
{
namespace
pybind
{
class
RecordIOWriter
{
public:
RecordIOWriter
(
const
std
::
string
&
filename
,
recordio
::
Compressor
compressor
,
size_t
max_num_record
)
:
stream_
(
filename
),
writer_
(
&
stream_
,
compressor
,
max_num_record
)
{}
void
AppendTensor
(
const
framework
::
LoDTensor
&
tensor
)
{
tensors_
.
push_back
(
tensor
);
}
void
CompleteAppendTensor
()
{
auto
&
ctx
=
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
framework
::
WriteToRecordIO
(
writer_
,
tensors_
,
ctx
);
tensors_
.
clear
();
}
void
Close
()
{
PADDLE_ENFORCE
(
tensors_
.
empty
());
writer_
.
Flush
();
stream_
.
close
();
}
private:
std
::
vector
<
framework
::
LoDTensor
>
tensors_
;
std
::
ofstream
stream_
;
recordio
::
Writer
writer_
;
};
void
BindRecordIOWriter
(
py
::
module
&
m
)
{
py
::
class_
<
RecordIOWriter
>
writer
(
m
,
"RecordIOWriter"
,
""
);
py
::
enum_
<
recordio
::
Compressor
>
(
writer
,
"Compressor"
,
""
)
.
value
(
"Snappy"
,
recordio
::
Compressor
::
kSnappy
)
.
value
(
"NoCompress"
,
recordio
::
Compressor
::
kNoCompress
);
writer
.
def
(
"__init__"
,
[](
RecordIOWriter
&
self
,
const
std
::
string
&
filename
,
recordio
::
Compressor
compressor
,
size_t
max_num_record
)
{
new
(
&
self
)
RecordIOWriter
(
filename
,
compressor
,
max_num_record
);
})
.
def
(
"append_tensor"
,
&
RecordIOWriter
::
AppendTensor
)
.
def
(
"complete_append_tensor"
,
&
RecordIOWriter
::
CompleteAppendTensor
)
.
def
(
"close"
,
&
RecordIOWriter
::
Close
);
}
}
// namespace pybind
}
// namespace paddle
paddle/fluid/pybind/recordio.h
0 → 100644
浏览文件 @
72be7a61
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
pybind
{
extern
void
BindRecordIOWriter
(
py
::
module
&
m
);
}
// namespace pybind
}
// namespace paddle
paddle/fluid/recordio/chunk.cc
浏览文件 @
72be7a61
...
...
@@ -25,32 +25,36 @@ namespace recordio {
constexpr
size_t
kMaxBufSize
=
1024
;
template
<
typename
Callback
>
static
void
ReadStreamByBuf
(
std
::
istream
&
in
,
in
t
limit
,
Callback
callback
)
{
static
void
ReadStreamByBuf
(
std
::
istream
&
in
,
size_
t
limit
,
Callback
callback
)
{
char
buf
[
kMaxBufSize
];
std
::
streamsize
actual_size
;
size_t
counter
=
0
;
do
{
auto
actual_max
=
limit
>
0
?
std
::
min
(
limit
-
counter
,
kMaxBufSize
)
:
kMaxBufSize
;
actual_size
=
in
.
readsome
(
buf
,
actual_max
);
size_t
actual_max
;
while
(
!
in
.
eof
()
||
(
limit
!=
0
&&
counter
>=
limit
))
{
actual_max
=
limit
!=
0
?
std
::
min
(
limit
-
counter
,
kMaxBufSize
)
:
kMaxBufSize
;
in
.
read
(
buf
,
actual_max
);
actual_size
=
in
.
gcount
();
if
(
actual_size
==
0
)
{
break
;
}
callback
(
buf
,
actual_size
);
if
(
limit
>
0
)
{
if
(
limit
!=
0
)
{
counter
+=
actual_size
;
}
}
while
(
actual_size
==
kMaxBufSize
);
}
in
.
clear
();
// unset eof state
}
static
void
PipeStream
(
std
::
istream
&
in
,
std
::
ostream
&
os
)
{
ReadStreamByBuf
(
in
,
-
1
,
[
&
os
](
const
char
*
buf
,
size_t
len
)
{
os
.
write
(
buf
,
len
);
});
in
,
0
,
[
&
os
](
const
char
*
buf
,
size_t
len
)
{
os
.
write
(
buf
,
len
);
});
}
static
uint32_t
Crc32Stream
(
std
::
istream
&
in
,
int
limit
=
-
1
)
{
auto
crc
=
crc32
(
0
,
nullptr
,
0
);
static
uint32_t
Crc32Stream
(
std
::
istream
&
in
,
size_t
limit
=
0
)
{
uint32_t
crc
=
static_cast
<
uint32_t
>
(
crc32
(
0
,
nullptr
,
0
)
);
ReadStreamByBuf
(
in
,
limit
,
[
&
crc
](
const
char
*
buf
,
size_t
len
)
{
crc
=
crc32
(
crc
,
reinterpret_cast
<
const
Bytef
*>
(
buf
),
len
);
crc
=
static_cast
<
uint32_t
>
(
crc32
(
crc
,
reinterpret_cast
<
const
Bytef
*>
(
buf
),
static_cast
<
uInt
>
(
len
)));
});
return
crc
;
}
...
...
@@ -85,14 +89,12 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const {
compressed_stream
.
reset
();
}
auto
end_pos
=
sout
.
tellg
();
sout
.
seekg
(
0
,
std
::
ios
::
beg
);
uint32_t
len
=
static_cast
<
uint32_t
>
(
end_pos
-
sout
.
tellg
());
uint32_t
len
=
static_cast
<
uint32_t
>
(
sout
.
str
().
size
());
uint32_t
crc
=
Crc32Stream
(
sout
);
sout
.
seekg
(
0
,
std
::
ios
::
beg
);
Header
hdr
(
static_cast
<
uint32_t
>
(
records_
.
size
()),
crc
,
ct
,
len
);
hdr
.
Write
(
os
);
sout
.
seekg
(
0
,
std
::
ios
::
beg
);
sout
.
clear
();
PipeStream
(
sout
,
os
);
return
true
;
}
...
...
@@ -104,12 +106,10 @@ bool Chunk::Parse(std::istream& sin) {
return
ok
;
}
auto
beg_pos
=
sin
.
tellg
();
auto
crc
=
Crc32Stream
(
sin
,
hdr
.
CompressSize
());
uint32_t
crc
=
Crc32Stream
(
sin
,
hdr
.
CompressSize
());
PADDLE_ENFORCE_EQ
(
hdr
.
Checksum
(),
crc
);
Clear
();
sin
.
seekg
(
beg_pos
,
std
::
ios
::
beg
);
sin
.
seekg
(
beg_pos
,
sin
.
beg
);
std
::
unique_ptr
<
std
::
istream
>
compressed_stream
;
switch
(
hdr
.
CompressType
())
{
case
Compressor
::
kNoCompress
:
...
...
paddle/fluid/recordio/header.cc
浏览文件 @
72be7a61
...
...
@@ -52,8 +52,8 @@ void Header::Write(std::ostream& os) const {
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Header
h
)
{
os
<<
h
.
NumRecords
()
<<
h
.
Checksum
()
<<
static_cast
<
uint32_t
>
(
h
.
CompressType
())
<<
h
.
CompressSize
();
os
<<
"Header: "
<<
h
.
NumRecords
()
<<
", "
<<
h
.
Checksum
()
<<
", "
<<
static_cast
<
uint32_t
>
(
h
.
CompressType
())
<<
", "
<<
h
.
CompressSize
();
return
os
;
}
...
...
python/paddle/fluid/__init__.py
浏览文件 @
72be7a61
...
...
@@ -39,6 +39,7 @@ import clip
from
memory_optimization_transpiler
import
memory_optimize
import
profiler
import
unique_name
import
recordio_writer
Tensor
=
LoDTensor
...
...
@@ -64,6 +65,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'memory_optimize'
,
'profiler'
,
'unique_name'
,
'recordio_writer'
,
]
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
72be7a61
...
...
@@ -13,11 +13,15 @@
# limitations under the License.
from
..
import
core
from
..layer_helper
import
LayerHelper
from
..framework
import
convert_np_dtype_to_dtype_
,
default_main_program
,
default_startup_program
from
..unique_name
import
generate
as
unique_name
from
control_flow
import
BlockGuard
from
..layer_helper
import
LayerHelper
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
]
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'read_file'
]
def
data
(
name
,
...
...
@@ -224,3 +228,55 @@ def Recv(endpoints, get_vars):
outputs
=
{
"Out"
:
get_vars
},
attrs
=
{
"endpoints"
:
endpoints
,
"epmap"
:
epmap
})
def
_copy_reader_var_
(
block
,
var
):
new_var
=
block
.
create_var
(
name
=
var
.
name
,
type
=
core
.
VarDesc
.
VarType
.
READER
)
new_var
.
desc
.
set_shapes
(
var
.
desc
.
shapes
())
new_var
.
desc
.
set_dtypes
(
var
.
desc
.
dtypes
())
new_var
.
persistable
=
True
return
new_var
def
open_recordio_file
(
filename
,
shapes
,
lod_levels
,
dtypes
):
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
shape_concat
=
[]
ranks
=
[]
for
shape
in
shapes
:
shape_concat
.
extend
(
shape
)
ranks
.
append
(
len
(
shape
))
var_name
=
unique_name
(
'open_recordio_file'
)
startup_blk
=
default_startup_program
().
current_block
()
startup_var
=
startup_blk
.
create_var
(
name
=
var_name
)
startup_blk
.
append_op
(
type
=
'create_recordio_file_reader'
,
outputs
=
{
'Out'
:
[
startup_var
]},
attrs
=
{
'shape_concat'
:
shape_concat
,
'lod_levels'
:
lod_levels
,
'filename'
:
filename
,
'ranks'
:
ranks
})
startup_var
.
desc
.
set_dtypes
(
dtypes
)
startup_var
.
persistable
=
True
return
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
def
read_file
(
file_obj
):
helper
=
LayerHelper
(
'read_file'
)
out
=
[
helper
.
create_tmp_variable
(
stop_gradient
=
True
,
dtype
=
'float32'
)
for
i
in
range
(
len
(
file_obj
.
desc
.
shapes
()))
]
helper
.
append_op
(
type
=
'read'
,
inputs
=
{
'Reader'
:
[
file_obj
]},
outputs
=
{
'Out'
:
out
})
if
len
(
out
)
==
1
:
return
out
[
0
]
else
:
return
out
python/paddle/fluid/recordio_writer.py
0 → 100644
浏览文件 @
72be7a61
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
core
class
RecordIOWriter
(
object
):
def
__init__
(
self
,
filename
,
compressor
=
core
.
RecordIOWriter
.
Compressor
.
Snappy
,
max_num_records
=
1000
):
self
.
filename
=
filename
self
.
compressor
=
compressor
self
.
max_num_records
=
max_num_records
self
.
writer
=
None
def
__enter__
(
self
):
self
.
writer
=
core
.
RecordIOWriter
(
self
.
filename
,
self
.
compressor
,
self
.
max_num_records
)
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
if
exc_type
is
not
None
:
return
False
else
:
self
.
writer
.
close
()
def
append_tensor
(
self
,
tensor
):
self
.
writer
.
append_tensor
(
tensor
)
def
complete_append_tensor
(
self
):
self
.
writer
.
complete_append_tensor
()
def
convert_reader_to_recordio_file
(
filename
,
reader_creator
,
feeder
,
compressor
=
core
.
RecordIOWriter
.
Compressor
.
Snappy
,
max_num_records
=
1000
,
feed_order
=
None
):
writer
=
RecordIOWriter
(
filename
,
compressor
,
max_num_records
)
with
writer
:
for
batch
in
reader_creator
():
res
=
feeder
.
feed
(
batch
)
if
feed_order
is
None
:
for
each
in
res
:
writer
.
append_tensor
(
res
[
each
])
else
:
for
each
in
feed_order
:
writer
.
append_tensor
(
res
[
each
])
writer
.
complete_append_tensor
()
python/paddle/fluid/tests/unittests/test_recordio_reader.py
0 → 100644
浏览文件 @
72be7a61
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
paddle.v2.dataset.mnist
as
mnist
import
paddle.v2
as
paddle
class
TestRecordIO
(
unittest
.
TestCase
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
reader
=
paddle
.
batch
(
mnist
.
train
(),
batch_size
=
32
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
]),
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
],
place
=
fluid
.
CPUPlace
())
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
'./mnist.recordio'
,
reader
,
feeder
,
feed_order
=
[
'image'
,
'label'
])
def
testMain
(
self
):
data_file
=
fluid
.
layers
.
open_recordio_file
(
'./mnist.recordio'
,
shapes
=
[[
-
1
,
784
],
[
-
1
,
1
]],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
img
,
label
=
fluid
.
layers
.
read_file
(
data_file
)
hidden
=
fluid
.
layers
.
fc
(
input
=
img
,
size
=
100
,
act
=
'tanh'
)
prediction
=
fluid
.
layers
.
fc
(
input
=
hidden
,
size
=
10
,
act
=
'softmax'
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
fluid
.
optimizer
.
SGD
(
learning_rate
=
1e-3
).
minimize
(
avg_loss
)
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
fluid
.
default_startup_program
())
avg_loss_np
,
=
exe
.
run
(
fetch_list
=
[
avg_loss
])
print
avg_loss_np
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录