Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
e15d616e
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e15d616e
编写于
5月 10, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete the C++ core of 'CustomReader'
上级
e61a38da
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
91 addition
and
17 deletion
+91
-17
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+1
-2
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+1
-0
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+88
-15
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+1
-0
未找到文件。
paddle/fluid/framework/shape_inference.h
浏览文件 @
e15d616e
...
...
@@ -63,6 +63,7 @@ class InferShapeContext {
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
);
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
);
virtual
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
=
0
;
// Note: In while op, we need this to be public
void
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
...
...
@@ -81,8 +82,6 @@ class InferShapeContext {
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
virtual
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
=
0
;
virtual
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
=
0
;
};
}
// namespace framework
...
...
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
e15d616e
...
...
@@ -23,6 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o
reader_library
(
create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc
)
reader_library
(
create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc
)
reader_library
(
create_threaded_reader_op SRCS create_threaded_reader_op.cc
)
reader_library
(
create_custom_reader_op SRCS create_custom_reader_op.cc
)
cc_test
(
reader_blocking_queue_test SRCS reader_blocking_queue_test.cc
)
# Export local libraries to parent
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
e15d616e
...
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
...
...
@@ -77,29 +78,101 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void
CustomReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
PADDLE_ENFORCE_EQ
(
source_var_names_
.
size
(),
out
->
size
(),
"The size of source_var_names(%d) not equals to the size of 'out'(%d). "
"Each element of 'out' must have its own source var in the CustomReader."
,
source_var_names_
.
size
(),
out
->
size
());
PADDLE_ENFORCE_EQ
(
sink_var_names_
.
size
(),
out
->
size
(),
"The size of sink_var_names(%d) not equals to the size of 'out'(%d). "
"Each element of 'out' must have its own sink var in the CustomReader."
,
sink_var_names_
.
size
(),
out
->
size
());
class
CustomReaderInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
IsRuntime
(),
"'CustomReaderInferShape' should only be invoked during "
"compile time."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output decorated reader should not be null."
);
const
auto
sink_var_names
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
);
std
::
vector
<
std
::
vector
<
int64_t
>>
res_dims
;
std
::
vector
<
int32_t
>
res_lod_levels
;
for
(
const
std
::
string
&
var_name
:
sink_var_names
)
{
auto
*
sink_var
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetVarPtr
(
var_name
));
PADDLE_ENFORCE_NOT_NULL
(
sink_var
);
res_dims
.
emplace_back
(
sink_var
->
GetShape
());
res_lod_levels
.
push_back
(
sink_var
->
GetLoDLevel
());
}
auto
*
out_reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
out_reader
->
SetShapes
(
res_dims
);
out_reader
->
SetLoDLevels
(
res_lod_levels
);
}
};
class
CustomReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
framework
::
VarDesc
*
out_reader
=
block
->
FindVar
(
op_desc
.
Output
(
"Out"
)[
0
]);
PADDLE_ENFORCE_NOT_NULL
(
out_reader
);
out_reader
->
SetType
(
framework
::
proto
::
VarType
::
READER
);
auto
sink_var_names
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op_desc
.
GetAttr
(
"sink_var_names"
));
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
res_data_types
;
for
(
const
std
::
string
&
var_name
:
sink_var_names
)
{
framework
::
VarDesc
*
var
=
block
->
FindVar
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
res_data_types
.
emplace_back
(
var
->
GetDataType
());
}
out_reader
->
SetDataTypes
(
res_data_types
);
}
};
void
CustomReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
out
->
clear
();
std
::
vector
<
framework
::
LoDTensor
>
underlying_outs
;
reader_
->
ReadNext
(
&
underlying_outs
);
if
(
underlying_outs
.
empty
())
{
// There is not next data.
return
;
}
PADDLE_ENFORCE
(
source_var_names_
.
size
()
==
underlying_outs
.
size
()
&&
sink_var_names_
.
size
()
==
underlying_outs
.
size
(),
"The size of source_var_names(%d), the size of sink_var_names(%d) and "
"the size of underlying_outs(%d) are not consistent. Each feeding "
"element must have its own source and sink variable."
,
source_var_names_
.
size
(),
sink_var_names_
.
size
(),
underlying_outs
.
size
());
// 1. Copy LoDTensors from underlying reader's output to source variables.
for
(
size_t
i
=
0
;
i
<
source_var_names_
.
size
();
++
i
)
{
const
std
::
string
&
var_name
=
source_var_names_
[
i
];
framework
::
Variable
*
var
=
scope_
.
FindVar
(
var_name
);
framework
::
Variable
*
var
=
scope_
.
FindVar
(
source_var_names_
[
i
]);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"CustomReader's source variable '%s' doesn't exist."
);
framework
::
LoDTensor
*
tensor
=
var
->
GetMutable
<
framework
::
loDtensor
>
();
framework
::
LoDTensor
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
ShareDataWith
(
underlying_outs
[
i
]);
tensor
->
set_lod
(
underlying_outs
[
i
].
lod
());
}
//
TODO(fengjiayi): 将vector中的数据拷贝到sorce_var和sink_var中
//
2. Run the sub-block.
framework
::
Executor
executor
(
dev_place_
);
framework
::
ProgramDesc
*
program
=
sub_block_
.
Program
();
framework
::
Scope
*
exe_scope
=
&
scope_
.
NewScope
();
executor
.
Run
(
*
program
,
exe_scope
,
sub_block_
.
ID
(),
false
/*create_local_scope*/
,
true
);
scope_
.
DeleteScope
(
exe_scope
);
// 3. Copy LoDTensors from sink variables to out.
out
->
resize
(
sink_var_names_
.
size
());
for
(
size_t
i
=
0
;
i
<
sink_var_names_
.
size
();
++
i
)
{
framework
::
Variable
*
var
=
scope_
.
FindVar
(
sink_var_names_
[
i
]);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"CustomReader's sink variable '%s' doesn't exist."
);
const
framework
::
LoDTensor
&
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
(
*
out
)[
i
].
ShareDataWith
(
tensor
);
(
*
out
)[
i
].
set_lod
(
tensor
.
lod
());
}
}
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
::
reader
;
REGISTER_OPERATOR
(
create_custom_reader
,
ops
::
CreateCustomReaderOp
,
ops
::
CreateCustomReaderOpMaker
,
ops
::
CustomReaderInferShape
,
ops
::
CustomReaderInferVarType
,
paddle
::
framework
::
EmptyGradOpMaker
)
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
e15d616e
...
...
@@ -117,6 +117,7 @@ void DecoratedReaderInferShape::operator()(
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
out_reader
->
SetLoDLevels
(
in_reader
->
GetLoDLevels
());
}
void
DecoratedReaderInferVarType
::
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
{
std
::
string
in_reader_name
=
op_desc
.
Input
(
"UnderlyingReader"
)[
0
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录