Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
42e65a20
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
42e65a20
编写于
3月 07, 2018
作者:
Y
Yu Yang
提交者:
GitHub
3月 07, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #8791 from reyoung/feature/extract_reader_ops
Extract create_reader_op to three files
上级
87568cfd
4690b9c9
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
553 addition
and
415 deletion
+553
-415
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+0
-87
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+2
-77
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+8
-4
paddle/fluid/operators/create_reader_op.cc
paddle/fluid/operators/create_reader_op.cc
+0
-246
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+3
-1
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+5
-0
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+137
-0
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+110
-0
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+97
-0
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+116
-0
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+75
-0
未找到文件。
paddle/fluid/framework/reader.cc
浏览文件 @
42e65a20
...
...
@@ -25,92 +25,5 @@ DDim ReaderBase::shape(size_t idx) const {
return
shapes_
[
idx
];
}
void
ShuffleReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
// Reload buffer with new data
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
for
(
int
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
reader_
->
HasNext
())
{
buffer_
.
push_back
(
std
::
vector
<
LoDTensor
>
());
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
else
{
break
;
}
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimize.
std
::
random_shuffle
(
buffer_
.
begin
(),
buffer_
.
end
());
iteration_pos_
=
0
;
}
out
->
clear
();
if
(
!
buffer_
.
empty
())
{
std
::
swap
(
*
out
,
buffer_
[
iteration_pos_
++
]);
}
// if buffer_ is empty, the 'out' will return as an empty vector.
}
void
BatchReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
buffer_
.
clear
();
buffer_
.
reserve
(
batch_size_
);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
if
(
reader_
->
HasNext
())
{
buffer_
.
push_back
(
std
::
vector
<
LoDTensor
>
());
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
else
{
break
;
}
}
// Concat instances
out
->
clear
();
if
(
buffer_
.
empty
())
{
// if buffer_ is empty, the 'out' will return as an empty vector.
return
;
}
int
out_num
=
buffer_
[
0
].
size
();
out
->
reserve
(
out_num
);
for
(
int
j
=
0
;
j
<
out_num
;
++
j
)
{
// Merge shape and check date type
std
::
type_index
batch_type
=
buffer_
[
0
][
j
].
type
();
DDim
batch_shape
=
buffer_
[
0
][
j
].
dims
();
for
(
size_t
i
=
1
;
i
<
buffer_
.
size
();
++
i
)
{
std
::
type_index
ins_type
=
buffer_
[
i
][
j
].
type
();
DDim
ins_shape
=
buffer_
[
i
][
j
].
dims
();
PADDLE_ENFORCE_EQ
(
batch_type
,
ins_type
);
PADDLE_ENFORCE_EQ
(
slice_ddim
(
batch_shape
,
1
,
batch_shape
.
size
()),
slice_ddim
(
ins_shape
,
1
,
ins_shape
.
size
()));
PADDLE_ENFORCE_GT
(
ins_shape
[
0
],
0
);
batch_shape
[
0
]
+=
ins_shape
[
0
];
}
LoDTensor
out_tensor
;
out_tensor
.
Resize
(
batch_shape
);
out_tensor
.
mutable_data
(
platform
::
CPUPlace
(),
batch_type
);
int64_t
dst_offset
=
0
;
// Merge lod and data
LoD
batch_lod
;
for
(
size_t
i
=
0
;
i
<
buffer_
.
size
();
++
i
)
{
DDim
ins_shape
=
buffer_
[
i
][
j
].
dims
();
LoD
ins_lod
=
buffer_
[
i
][
j
].
lod
();
if
(
i
==
0
)
{
batch_lod
=
ins_lod
;
}
else
{
PADDLE_ENFORCE_EQ
(
batch_lod
.
size
(),
ins_lod
.
size
());
for
(
size_t
level_idx
=
0
;
level_idx
<
batch_lod
.
size
();
++
level_idx
)
{
auto
&
lod_level
=
batch_lod
[
level_idx
];
for
(
size_t
k
=
1
;
k
<
ins_lod
[
level_idx
].
size
();
++
k
)
{
lod_level
.
push_back
(
ins_lod
[
level_idx
][
k
]
+
lod_level
.
back
());
}
}
}
Tensor
dst
=
out_tensor
.
Slice
(
dst_offset
,
dst_offset
+
ins_shape
[
0
]);
TensorCopy
(
buffer_
[
i
][
j
],
platform
::
CPUPlace
(),
&
dst
);
dst_offset
+=
ins_shape
[
0
];
}
out_tensor
.
set_lod
(
batch_lod
);
out
->
push_back
(
out_tensor
);
}
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
42e65a20
...
...
@@ -60,83 +60,8 @@ class DecoratedReader : public ReaderBase {
ReaderBase
*
reader_
;
};
// file readers
template
<
typename
T
>
class
RandomDataGenerator
:
public
FileReader
{
public:
RandomDataGenerator
(
const
std
::
vector
<
DDim
>&
shapes
,
float
min
,
float
max
)
:
FileReader
(
shapes
),
min_
(
min
),
max_
(
max
)
{
PADDLE_ENFORCE_LE
(
min
,
max
,
"'min' shouldn't be greater than 'max'.(%f vs %f)"
,
min
,
max
);
unsigned
int
seed
=
std
::
random_device
()();
engine_
.
seed
(
seed
);
dist_
=
std
::
uniform_real_distribution
<
float
>
(
min_
,
max_
);
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
override
{
out
->
clear
();
out
->
reserve
(
shapes_
.
size
());
for
(
const
DDim
&
shape
:
shapes_
)
{
PADDLE_ENFORCE_GE
(
shape
.
size
(),
2
,
"The rank of reader's output data should be 2 at least.(Now it's %d)"
,
shape
.
size
());
LoDTensor
out_tensor
;
out_tensor
.
Resize
(
shape
);
T
*
data
=
out_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
int64_t
numel
=
product
(
shape
);
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
data
[
i
]
=
dist_
(
engine_
);
}
out
->
push_back
(
out_tensor
);
}
}
bool
HasNext
()
const
override
{
return
true
;
}
void
ReInit
()
override
{
return
;
}
private:
float
min_
;
float
max_
;
std
::
minstd_rand
engine_
;
std
::
uniform_real_distribution
<
float
>
dist_
;
};
// decorated readers
class
ShuffleReader
:
public
DecoratedReader
{
public:
ShuffleReader
(
ReaderBase
*
reader
,
int
buffer_size
)
:
DecoratedReader
(
reader
),
buffer_size_
(
buffer_size
),
iteration_pos_
(
0
)
{
buffer_
.
reserve
(
buffer_size
);
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
override
;
private:
int
buffer_size_
;
std
::
vector
<
std
::
vector
<
LoDTensor
>>
buffer_
;
size_t
iteration_pos_
;
};
class
BatchReader
:
public
DecoratedReader
{
public:
BatchReader
(
ReaderBase
*
reader
,
int
batch_size
)
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
)
{
buffer_
.
reserve
(
batch_size_
);
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
override
;
private:
int
batch_size_
;
std
::
vector
<
std
::
vector
<
LoDTensor
>>
buffer_
;
};
// The ReaderHolder is used as readers' unified wrapper,
// making it easier to access different type readers in Variables.
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class
ReaderHolder
{
public:
void
Reset
(
ReaderBase
*
reader
)
{
reader_
.
reset
(
reader
);
}
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
42e65a20
...
...
@@ -70,7 +70,7 @@ function(op_library TARGET)
endif
()
# Define operators that don't need pybind here.
foreach
(
manual_pybind_op
"net_op"
"compare_op"
"logical_op"
"nccl_op"
"tensor_array_read_write_op"
"create_reader_op"
)
foreach
(
manual_pybind_op
"net_op"
"compare_op"
"logical_op"
"nccl_op"
"tensor_array_read_write_op"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
manual_pybind_op
}
"
)
set
(
pybind_flag 1
)
endif
()
...
...
@@ -128,8 +128,8 @@ else()
set
(
DEPS_OPS
${
DEPS_OPS
}
nccl_op
)
endif
()
add_subdirectory
(
detail
)
if
(
WITH_DISTRIBUTE
)
add_subdirectory
(
detail
)
set
(
DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
op_library
(
send_op DEPS
${
DISTRIBUTE_DEPS
}
)
...
...
@@ -170,7 +170,6 @@ op_library(recurrent_op DEPS executor)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
op_library
(
cos_sim_op DEPS cos_sim_functor
)
op_library
(
parallel_do_op DEPS executor
)
op_library
(
create_reader_op DEPS reader
)
if
(
WITH_GPU
)
op_library
(
conv_op DEPS vol2col depthwise_conv
)
...
...
@@ -189,7 +188,12 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach
(
src
${
GENERAL_OPS
}
)
op_library
(
${
src
}
)
endforeach
()
file
(
APPEND
${
pybind_file
}
"USE_OP(less_than);
\n
USE_OP(logical_and);
\n
USE_NO_KERNEL_OP(read_from_array);
\n
USE_NO_KERNEL_OP(create_random_data_generator);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(less_than);
\n
USE_OP(logical_and);
\n
USE_NO_KERNEL_OP(read_from_array);
\n
"
)
add_subdirectory
(
reader
)
foreach
(
src
${
READER_LIBRARY
}
)
set
(
OP_LIBRARY
${
src
}
${
OP_LIBRARY
}
)
endforeach
()
set
(
GLOB_OP_LIB
${
OP_LIBRARY
}
CACHE INTERNAL
"Global OP library"
)
...
...
paddle/fluid/operators/create_reader_op.cc
已删除
100644 → 0
浏览文件 @
87568cfd
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
namespace
paddle
{
namespace
operators
{
static
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
)
{
std
::
vector
<
framework
::
DDim
>
res
;
int
offset
=
0
;
for
(
int
len
:
ranks
)
{
auto
start_it
=
shape_concat
.
begin
()
+
offset
;
auto
end_it
=
start_it
+
len
;
res
.
push_back
(
framework
::
make_ddim
(
std
::
vector
<
int
>
(
start_it
,
end_it
)));
offset
+=
len
;
}
return
res
;
}
// general infershape for file readers
class
CreateFileReaderInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output file reader should not be null."
);
const
auto
shape_concat
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
ranks
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ranks"
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
ctx
->
SetReaderDims
(
"Out"
,
shapes
);
if
(
ctx
->
IsRuntime
())
{
const
auto
lod_levels
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"lod_levels"
);
PADDLE_ENFORCE_EQ
(
lod_levels
.
size
(),
shapes
.
size
(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d)."
,
lod_levels
.
size
(),
shapes
.
size
());
framework
::
VarDesc
*
reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
reader
->
SetLoDLevels
(
lod_levels
);
}
}
};
// general infershape for decorated readers
class
CreateDecoratedReaderInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"UnderlyingReader"
),
"Input(UnderlyingReader) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output decorated reader should not be null."
);
ctx
->
SetReaderDims
(
"Out"
,
ctx
->
GetReaderDims
(
"UnderlyingReader"
));
if
(
ctx
->
IsRuntime
())
{
framework
::
VarDesc
*
in_reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetInputVarPtrs
(
"UnderlyingReader"
)[
0
]);
framework
::
VarDesc
*
out_reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
out_reader
->
SetLoDLevels
(
in_reader
->
GetLoDLevels
());
}
}
};
// general var type inference for file readers
class
CreateFileReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
string
reader_name
=
op_desc
.
Output
(
"Out"
)[
0
];
framework
::
VarDesc
*
reader
=
block
->
FindVarRecursive
(
reader_name
);
reader
->
SetType
(
framework
::
proto
::
VarType
::
READER
);
}
};
// general var type inference for decorated readers
class
CreateDecoratedReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
string
in_reader_name
=
op_desc
.
Input
(
"UnderlyingReader"
)[
0
];
framework
::
VarDesc
*
in_reader
=
block
->
FindVarRecursive
(
in_reader_name
);
std
::
string
out_reader_name
=
op_desc
.
Output
(
"Out"
)[
0
];
framework
::
VarDesc
*
out_reader
=
block
->
FindVarRecursive
(
out_reader_name
);
out_reader
->
SetType
(
framework
::
proto
::
VarType
::
READER
);
out_reader
->
SetDataTypes
(
in_reader
->
GetDataTypes
());
}
};
template
<
typename
T
>
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
framework
::
RandomDataGenerator
<
T
>
(
shapes
,
Attr
<
float
>
(
"min"
),
Attr
<
float
>
(
"max"
)));
}
};
class
CreateRandomDataGeneratorOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
CreateRandomDataGeneratorOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"(ReaderHolder) The created random reader."
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape_concat"
,
"The concat of all data's shapes."
);
AddAttr
<
std
::
vector
<
int
>>
(
"ranks"
,
"The ranks of each data."
"e.g."
"shape_concat = [2,3,4,5,6]"
"ranks = [3,2]"
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."
);
AddAttr
<
std
::
vector
<
int
>>
(
"lod_levels"
,
"The LoD levels of each data."
);
AddAttr
<
float
>
(
"min"
,
"The lower bound of reader's uniform distribution."
);
AddAttr
<
float
>
(
"max"
,
"The upper bound of reader's uniform distribution."
);
AddComment
(
R"DOC(
CreateRandomDataGenerator Operator
This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'.
)DOC"
);
}
};
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
framework
::
ShuffleReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"buffer_size"
)));
}
};
class
CreateShuffleReaderOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
CreateShuffleReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddInput
(
"UnderlyingReader"
,
"(ReaderHolder) The underlying reader for creating a shuffle reader."
);
AddOutput
(
"Out"
,
"(ReaderHolder) The created shuffle reader."
);
AddAttr
<
int
>
(
"buffer_size"
,
"The shuffle buffer size."
).
GreaterThan
(
0
);
AddComment
(
R"DOC(
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
)DOC"
);
}
};
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
framework
::
BatchReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"batch_size"
)));
}
};
class
CreateBatchReaderOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
CreateBatchReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddInput
(
"UnderlyingReader"
,
"(ReaderHolder) The underlying reader for creating a batch reader."
);
AddOutput
(
"Out"
,
"(ReaderHolder) The created batch reader."
);
AddAttr
<
int
>
(
"batch_size"
,
"How many instances the batch reader yields each time."
)
.
GreaterThan
(
0
);
AddComment
(
R"DOC(
CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
create_random_data_generator
,
ops
::
CreateRandomDataGeneratorOp
<
float
>
,
ops
::
CreateFileReaderInferShape
,
ops
::
CreateRandomDataGeneratorOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateFileReaderInferVarType
);
REGISTER_OPERATOR
(
create_shuffle_reader
,
ops
::
CreateShuffleReaderOp
,
ops
::
CreateDecoratedReaderInferShape
,
ops
::
CreateShuffleReaderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateDecoratedReaderInferVarType
);
REGISTER_OPERATOR
(
create_batch_reader
,
ops
::
CreateBatchReaderOp
,
ops
::
CreateDecoratedReaderInferShape
,
ops
::
CreateBatchReaderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateDecoratedReaderInferVarType
);
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
42e65a20
grpc_library
(
sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
if
(
WITH_DISTRIBUTE
)
grpc_library
(
sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
endif
()
paddle/fluid/operators/reader/CMakeLists.txt
0 → 100644
浏览文件 @
42e65a20
cc_library
(
reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader
)
op_library
(
create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry
)
op_library
(
create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry
)
op_library
(
create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry
)
set
(
READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op PARENT_SCOPE
)
paddle/fluid/operators/reader/create_batch_reader_op.cc
0 → 100644
浏览文件 @
42e65a20
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
BatchReader
:
public
framework
::
DecoratedReader
{
public:
BatchReader
(
ReaderBase
*
reader
,
int
batch_size
)
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
)
{
buffer_
.
reserve
(
batch_size_
);
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
int
batch_size_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
};
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
BatchReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"batch_size"
)));
}
};
class
CreateBatchReaderOpMaker
:
public
DecoratedReaderMakerBase
{
public:
CreateBatchReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
DecoratedReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
int
>
(
"batch_size"
,
"How many instances the batch reader yields each time."
)
.
GreaterThan
(
0
);
AddComment
(
R"DOC(
CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC"
);
}
};
void
BatchReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
buffer_
.
clear
();
buffer_
.
reserve
(
batch_size_
);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
if
(
reader_
->
HasNext
())
{
buffer_
.
push_back
(
std
::
vector
<
framework
::
LoDTensor
>
());
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
else
{
break
;
}
}
// Concat instances
out
->
clear
();
if
(
buffer_
.
empty
())
{
// if buffer_ is empty, the 'out' will return as an empty vector.
return
;
}
int
out_num
=
buffer_
[
0
].
size
();
out
->
reserve
(
out_num
);
for
(
int
j
=
0
;
j
<
out_num
;
++
j
)
{
// Merge shape and check date type
std
::
type_index
batch_type
=
buffer_
[
0
][
j
].
type
();
framework
::
DDim
batch_shape
=
buffer_
[
0
][
j
].
dims
();
for
(
size_t
i
=
1
;
i
<
buffer_
.
size
();
++
i
)
{
std
::
type_index
ins_type
=
buffer_
[
i
][
j
].
type
();
framework
::
DDim
ins_shape
=
buffer_
[
i
][
j
].
dims
();
PADDLE_ENFORCE_EQ
(
batch_type
,
ins_type
);
PADDLE_ENFORCE_EQ
(
slice_ddim
(
batch_shape
,
1
,
batch_shape
.
size
()),
slice_ddim
(
ins_shape
,
1
,
ins_shape
.
size
()));
PADDLE_ENFORCE_GT
(
ins_shape
[
0
],
0
);
batch_shape
[
0
]
+=
ins_shape
[
0
];
}
framework
::
LoDTensor
out_tensor
;
out_tensor
.
Resize
(
batch_shape
);
out_tensor
.
mutable_data
(
platform
::
CPUPlace
(),
batch_type
);
int64_t
dst_offset
=
0
;
// Merge lod and data
framework
::
LoD
batch_lod
;
for
(
size_t
i
=
0
;
i
<
buffer_
.
size
();
++
i
)
{
framework
::
DDim
ins_shape
=
buffer_
[
i
][
j
].
dims
();
framework
::
LoD
ins_lod
=
buffer_
[
i
][
j
].
lod
();
if
(
i
==
0
)
{
batch_lod
=
ins_lod
;
}
else
{
PADDLE_ENFORCE_EQ
(
batch_lod
.
size
(),
ins_lod
.
size
());
for
(
size_t
level_idx
=
0
;
level_idx
<
batch_lod
.
size
();
++
level_idx
)
{
auto
&
lod_level
=
batch_lod
[
level_idx
];
for
(
size_t
k
=
1
;
k
<
ins_lod
[
level_idx
].
size
();
++
k
)
{
lod_level
.
push_back
(
ins_lod
[
level_idx
][
k
]
+
lod_level
.
back
());
}
}
}
auto
dst
=
out_tensor
.
Slice
(
dst_offset
,
dst_offset
+
ins_shape
[
0
]);
TensorCopy
(
buffer_
[
i
][
j
],
platform
::
CPUPlace
(),
&
dst
);
dst_offset
+=
ins_shape
[
0
];
}
out_tensor
.
set_lod
(
batch_lod
);
out
->
push_back
(
out_tensor
);
}
}
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
::
reader
;
REGISTER_DECORATED_READER_OPERATOR
(
create_batch_reader
,
ops
::
CreateBatchReaderOp
,
ops
::
CreateBatchReaderOpMaker
);
paddle/fluid/operators/reader/create_random_data_generator_op.cc
0 → 100644
浏览文件 @
42e65a20
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
template
<
typename
T
>
class
RandomDataGenerator
:
public
framework
::
FileReader
{
public:
RandomDataGenerator
(
const
std
::
vector
<
framework
::
DDim
>&
shapes
,
float
min
,
float
max
)
:
FileReader
(
shapes
),
min_
(
min
),
max_
(
max
)
{
PADDLE_ENFORCE_LE
(
min
,
max
,
"'min' shouldn't be greater than 'max'.(%f vs %f)"
,
min
,
max
);
unsigned
int
seed
=
std
::
random_device
()();
engine_
.
seed
(
seed
);
dist_
=
std
::
uniform_real_distribution
<
float
>
(
min_
,
max_
);
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
out
->
clear
();
out
->
reserve
(
shapes_
.
size
());
for
(
const
framework
::
DDim
&
shape
:
shapes_
)
{
PADDLE_ENFORCE_GE
(
shape
.
size
(),
2
,
"The rank of reader's output data should be 2 at least.(Now it's %d)"
,
shape
.
size
());
framework
::
LoDTensor
out_tensor
;
out_tensor
.
Resize
(
shape
);
T
*
data
=
out_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
int64_t
numel
=
framework
::
product
(
shape
);
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
data
[
i
]
=
dist_
(
engine_
);
}
out
->
push_back
(
out_tensor
);
}
}
bool
HasNext
()
const
override
{
return
true
;
}
void
ReInit
()
override
{
return
;
}
private:
float
min_
;
float
max_
;
std
::
minstd_rand
engine_
;
std
::
uniform_real_distribution
<
float
>
dist_
;
};
template
<
typename
T
>
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RandomDataGenerator
<
T
>
(
shapes
,
Attr
<
float
>
(
"min"
),
Attr
<
float
>
(
"max"
)));
}
};
class
CreateRandomDataGeneratorOpMaker
:
public
FileReaderMakerBase
{
public:
CreateRandomDataGeneratorOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
FileReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
float
>
(
"min"
,
"The lower bound of reader's uniform distribution."
);
AddAttr
<
float
>
(
"max"
,
"The upper bound of reader's uniform distribution."
);
AddComment
(
R"DOC(
CreateRandomDataGenerator Operator
This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'.
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
::
reader
;
REGISTER_FILE_READER_OPERATOR
(
create_random_data_generator
,
ops
::
CreateRandomDataGeneratorOp
<
float
>
,
ops
::
CreateRandomDataGeneratorOpMaker
);
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
0 → 100644
浏览文件 @
42e65a20
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
ShuffleReader
:
public
framework
::
DecoratedReader
{
public:
ShuffleReader
(
ReaderBase
*
reader
,
int
buffer_size
)
:
DecoratedReader
(
reader
),
buffer_size_
(
buffer_size
),
iteration_pos_
(
0
)
{
buffer_
.
reserve
(
buffer_size
);
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
int
buffer_size_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
size_t
iteration_pos_
;
};
void
ShuffleReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
// Reload buffer with new data
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
for
(
int
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
reader_
->
HasNext
())
{
buffer_
.
push_back
(
std
::
vector
<
framework
::
LoDTensor
>
());
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
else
{
break
;
}
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimize.
std
::
random_shuffle
(
buffer_
.
begin
(),
buffer_
.
end
());
iteration_pos_
=
0
;
}
out
->
clear
();
if
(
!
buffer_
.
empty
())
{
std
::
swap
(
*
out
,
buffer_
[
iteration_pos_
++
]);
}
// if buffer_ is empty, the 'out' will return as an empty vector.
}
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"buffer_size"
)));
}
};
class
CreateShuffleReaderOpMaker
:
public
DecoratedReaderMakerBase
{
public:
CreateShuffleReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
DecoratedReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
int
>
(
"buffer_size"
,
"The shuffle buffer size."
).
GreaterThan
(
0
);
AddComment
(
R"DOC(
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
::
reader
;
REGISTER_DECORATED_READER_OPERATOR
(
create_shuffle_reader
,
ops
::
CreateShuffleReaderOp
,
ops
::
CreateShuffleReaderOpMaker
);
paddle/fluid/operators/reader/reader_op_registry.cc
0 → 100644
浏览文件 @
42e65a20
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
)
{
std
::
vector
<
framework
::
DDim
>
res
;
int
offset
=
0
;
for
(
int
len
:
ranks
)
{
auto
start_it
=
shape_concat
.
begin
()
+
offset
;
auto
end_it
=
start_it
+
len
;
res
.
push_back
(
framework
::
make_ddim
(
std
::
vector
<
int
>
(
start_it
,
end_it
)));
offset
+=
len
;
}
return
res
;
}
FileReaderMakerBase
::
FileReaderMakerBase
(
framework
::
OpProtoAndCheckerMaker
::
OpProto
*
op_proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"(ReaderHolder) The created random reader."
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape_concat"
,
"The concat of all data's shapes."
);
AddAttr
<
std
::
vector
<
int
>>
(
"ranks"
,
"The ranks of each data."
"e.g."
"shape_concat = [2,3,4,5,6]"
"ranks = [3,2]"
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."
);
AddAttr
<
std
::
vector
<
int
>>
(
"lod_levels"
,
"The LoD levels of each data."
);
}
void
FileReaderInferShape
::
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output file reader should not be null."
);
const
auto
shape_concat
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
ranks
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ranks"
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
ctx
->
SetReaderDims
(
"Out"
,
shapes
);
if
(
ctx
->
IsRuntime
())
{
const
auto
lod_levels
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"lod_levels"
);
PADDLE_ENFORCE_EQ
(
lod_levels
.
size
(),
shapes
.
size
(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d)."
,
lod_levels
.
size
(),
shapes
.
size
());
framework
::
VarDesc
*
reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
reader
->
SetLoDLevels
(
lod_levels
);
}
}
void
FileReaderInferVarType
::
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
{
std
::
string
reader_name
=
op_desc
.
Output
(
"Out"
)[
0
];
framework
::
VarDesc
*
reader
=
block
->
FindVarRecursive
(
reader_name
);
reader
->
SetType
(
framework
::
proto
::
VarType
::
READER
);
}
void
DecoratedReaderInferShape
::
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"UnderlyingReader"
),
"Input(UnderlyingReader) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output decorated reader should not be null."
);
ctx
->
SetReaderDims
(
"Out"
,
ctx
->
GetReaderDims
(
"UnderlyingReader"
));
if
(
ctx
->
IsRuntime
())
{
framework
::
VarDesc
*
in_reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetInputVarPtrs
(
"UnderlyingReader"
)[
0
]);
framework
::
VarDesc
*
out_reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
out_reader
->
SetLoDLevels
(
in_reader
->
GetLoDLevels
());
}
}
void
DecoratedReaderInferVarType
::
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
{
std
::
string
in_reader_name
=
op_desc
.
Input
(
"UnderlyingReader"
)[
0
];
framework
::
VarDesc
*
in_reader
=
block
->
FindVarRecursive
(
in_reader_name
);
std
::
string
out_reader_name
=
op_desc
.
Output
(
"Out"
)[
0
];
framework
::
VarDesc
*
out_reader
=
block
->
FindVarRecursive
(
out_reader_name
);
out_reader
->
SetType
(
framework
::
proto
::
VarType
::
READER
);
out_reader
->
SetDataTypes
(
in_reader
->
GetDataTypes
());
}
DecoratedReaderMakerBase
::
DecoratedReaderMakerBase
(
framework
::
OpProtoAndCheckerMaker
::
OpProto
*
op_proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddInput
(
"UnderlyingReader"
,
"(ReaderHolder) The underlying reader for creating a batch reader."
);
AddOutput
(
"Out"
,
"(ReaderHolder) The created batch reader."
);
}
}
// namespace reader
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reader/reader_op_registry.h
0 → 100644
浏览文件 @
42e65a20
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
extern
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
);
class
FileReaderMakerBase
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
FileReaderMakerBase
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
);
};
class
FileReaderInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
class
FileReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
;
};
// general infershape for decorated reader
class
DecoratedReaderInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
// general var type inference for decorated reader
class
DecoratedReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
;
};
class
DecoratedReaderMakerBase
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
DecoratedReaderMakerBase
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
);
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
#define REGISTER_FILE_READER_OPERATOR(op_name, ...) \
REGISTER_OPERATOR(op_name, __VA_ARGS__, \
paddle::operators::reader::FileReaderInferShape, \
paddle::framework::EmptyGradOpMaker, \
paddle::operators::reader::FileReaderInferVarType)
#define REGISTER_DECORATED_READER_OPERATOR(op_name, ...) \
REGISTER_OPERATOR(op_name, __VA_ARGS__, \
paddle::operators::reader::DecoratedReaderInferShape, \
paddle::framework::EmptyGradOpMaker, \
paddle::operators::reader::DecoratedReaderInferVarType)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录