Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
175aa7ea
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
175aa7ea
编写于
2月 10, 2018
作者:
F
fengjiayi
提交者:
Yi Wang
2月 09, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add lod and dtype inference (#8329)
上级
74492d5d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
89 addition
and
19 deletion
+89
-19
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+7
-0
paddle/framework/operator.cc
paddle/framework/operator.cc
+4
-0
paddle/framework/reader.cc
paddle/framework/reader.cc
+0
-6
paddle/framework/shape_inference.cc
paddle/framework/shape_inference.cc
+22
-0
paddle/framework/shape_inference.h
paddle/framework/shape_inference.h
+10
-0
paddle/operators/create_reader_op.cc
paddle/operators/create_reader_op.cc
+40
-5
python/paddle/v2/fluid/tests/test_cpp_reader.py
python/paddle/v2/fluid/tests/test_cpp_reader.py
+6
-8
未找到文件。
paddle/framework/op_desc.cc
浏览文件 @
175aa7ea
...
...
@@ -77,6 +77,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
override
;
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
override
;
const
OpDesc
&
op_
;
const
BlockDesc
&
block_
;
};
...
...
@@ -510,5 +512,10 @@ proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
return
block_
.
FindVarRecursive
(
name
)
->
GetType
();
}
InferShapeVarPtr
CompileTimeInferShapeContext
::
GetVarPtr
(
const
std
::
string
&
name
)
{
return
block_
.
FindVarRecursive
(
name
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/operator.cc
浏览文件 @
175aa7ea
...
...
@@ -470,6 +470,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
return
ToVarType
(
var
->
Type
());
}
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
override
{
return
scope_
.
FindVar
(
name
);
}
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
...
...
paddle/framework/reader.cc
浏览文件 @
175aa7ea
...
...
@@ -90,7 +90,6 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
// Merge lod and data
LoD
batch_lod
;
std
::
vector
<
size_t
>
top_level_lod
({
0
});
for
(
size_t
i
=
0
;
i
<
buffer_
.
size
();
++
i
)
{
DDim
ins_shape
=
buffer_
[
i
][
j
].
dims
();
LoD
ins_lod
=
buffer_
[
i
][
j
].
lod
();
...
...
@@ -105,15 +104,10 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
}
}
}
top_level_lod
.
push_back
(
top_level_lod
.
back
()
+
(
ins_lod
.
empty
()
?
ins_shape
[
0
]
:
(
ins_lod
[
0
].
size
()
-
1
)));
Tensor
dst
=
out_tensor
.
Slice
(
dst_offset
,
dst_offset
+
ins_shape
[
0
]);
Copy
(
buffer_
[
i
][
j
],
platform
::
CPUPlace
(),
&
dst
);
dst_offset
+=
ins_shape
[
0
];
}
batch_lod
.
insert
(
batch_lod
.
begin
(),
top_level_lod
);
out_tensor
.
set_lod
(
batch_lod
);
out
->
push_back
(
out_tensor
);
}
...
...
paddle/framework/shape_inference.cc
浏览文件 @
175aa7ea
...
...
@@ -72,6 +72,28 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return
this
->
SetRepeatedDims
(
arg_names
[
0
],
dims
);
}
std
::
vector
<
InferShapeVarPtr
>
InferShapeContext
::
GetInputVarPtrs
(
const
std
::
string
&
name
)
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Inputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetVarPtr
(
name
);
});
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
InferShapeContext
::
GetOutputVarPtrs
(
const
std
::
string
&
name
)
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Outputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetVarPtr
(
name
);
});
return
res
;
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
DDim
>
ret
;
...
...
paddle/framework/shape_inference.h
浏览文件 @
175aa7ea
...
...
@@ -17,10 +17,14 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/var_desc.h"
#include "paddle/framework/variable.h"
namespace
paddle
{
namespace
framework
{
using
InferShapeVarPtr
=
boost
::
variant
<
VarDesc
*
,
Variable
*>
;
class
InferShapeContext
{
public:
virtual
~
InferShapeContext
()
=
default
;
...
...
@@ -55,6 +59,9 @@ class InferShapeContext {
virtual
bool
IsRuntime
()
const
=
0
;
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
);
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
);
// Note: In while op, we need this to be public
void
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
DDim
>
&
dims
);
...
...
@@ -67,10 +74,13 @@ class InferShapeContext {
const
std
::
vector
<
DDim
>
&
dims
)
=
0
;
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
std
::
vector
<
proto
::
VarDesc
::
VarType
>
GetVarTypes
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
virtual
proto
::
VarDesc
::
VarType
GetVarType
(
const
std
::
string
&
name
)
const
=
0
;
virtual
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
=
0
;
};
}
// namespace framework
...
...
paddle/operators/create_reader_op.cc
浏览文件 @
175aa7ea
...
...
@@ -42,6 +42,18 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
const
auto
ranks
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ranks"
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
ctx
->
SetReaderDims
(
"Out"
,
shapes
);
if
(
ctx
->
IsRuntime
())
{
const
auto
lod_levels
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"lod_levels"
);
PADDLE_ENFORCE_EQ
(
lod_levels
.
size
(),
shapes
.
size
(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d)."
,
lod_levels
.
size
(),
shapes
.
size
());
framework
::
VarDesc
*
reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
reader
->
SetLoDLevels
(
lod_levels
);
}
}
};
...
...
@@ -54,11 +66,19 @@ class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output decorated reader should not be null."
);
ctx
->
SetReaderDims
(
"Out"
,
ctx
->
GetReaderDims
(
"UnderlyingReader"
));
if
(
ctx
->
IsRuntime
())
{
framework
::
VarDesc
*
in_reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetInputVarPtrs
(
"UnderlyingReader"
)[
0
]);
framework
::
VarDesc
*
out_reader
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetOutputVarPtrs
(
"Out"
)[
0
]);
out_reader
->
SetLoDLevels
(
in_reader
->
GetLoDLevels
());
}
}
};
// general var type inference for
all
readers
class
CreateReaderInferVarType
:
public
framework
::
VarTypeInference
{
// general var type inference for
file
readers
class
Create
File
ReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
...
...
@@ -68,6 +88,20 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
}
};
// general var type inference for decorated readers
class
CreateDecoratedReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
string
in_reader_name
=
op_desc
.
Input
(
"UnderlyingReader"
)[
0
];
framework
::
VarDesc
*
in_reader
=
block
->
FindVarRecursive
(
in_reader_name
);
std
::
string
out_reader_name
=
op_desc
.
Output
(
"Out"
)[
0
];
framework
::
VarDesc
*
out_reader
=
block
->
FindVarRecursive
(
out_reader_name
);
out_reader
->
SetType
(
framework
::
proto
::
VarDesc
::
READER
);
out_reader
->
SetDataTypes
(
in_reader
->
GetDataTypes
());
}
};
template
<
typename
T
>
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
public:
...
...
@@ -105,6 +139,7 @@ class CreateRandomDataGeneratorOpMaker
"ranks = [3,2]"
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."
);
AddAttr
<
std
::
vector
<
int
>>
(
"lod_levels"
,
"The LoD levels of each data."
);
AddAttr
<
float
>
(
"min"
,
"The lower bound of reader's uniform distribution."
);
AddAttr
<
float
>
(
"max"
,
"The upper bound of reader's uniform distribution."
);
AddComment
(
R"DOC(
...
...
@@ -192,14 +227,14 @@ REGISTER_OPERATOR(create_random_data_generator,
ops
::
CreateFileReaderInferShape
,
ops
::
CreateRandomDataGeneratorOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateReaderInferVarType
);
ops
::
Create
File
ReaderInferVarType
);
REGISTER_OPERATOR
(
create_shuffle_reader
,
ops
::
CreateShuffleReaderOp
,
ops
::
CreateDecoratedReaderInferShape
,
ops
::
CreateShuffleReaderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateReaderInferVarType
);
ops
::
Create
Decorated
ReaderInferVarType
);
REGISTER_OPERATOR
(
create_batch_reader
,
ops
::
CreateBatchReaderOp
,
ops
::
CreateDecoratedReaderInferShape
,
ops
::
CreateBatchReaderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateReaderInferVarType
);
ops
::
Create
Decorated
ReaderInferVarType
);
python/paddle/v2/fluid/tests/test_cpp_reader.py
浏览文件 @
175aa7ea
...
...
@@ -21,7 +21,8 @@ block = prog.current_block()
random_reader
=
block
.
create_var
(
type
=
fluid
.
core
.
VarDesc
.
VarType
.
READER
,
name
=
"RandomDataGenerator"
)
random_reader
.
desc
.
set_lod_levels
([
0
,
0
])
random_reader
.
desc
.
set_dtypes
(
[
fluid
.
core
.
DataType
.
FP32
,
fluid
.
core
.
DataType
.
FP32
])
create_random_data_generator_op
=
block
.
append_op
(
type
=
"create_random_data_generator"
,
...
...
@@ -30,11 +31,11 @@ create_random_data_generator_op = block.append_op(
"shape_concat"
:
[
1
,
2
,
1
,
1
],
"ranks"
:
[
2
,
2
],
"min"
:
0.0
,
"max"
:
1.0
"max"
:
1.0
,
'lod_levels'
:
[
0
,
0
]
})
shuffle_reader
=
block
.
create_var
(
type
=
fluid
.
core
.
VarDesc
.
VarType
.
READER
,
name
=
"ShuffleReader"
)
shuffle_reader
.
desc
.
set_lod_levels
([
0
,
0
])
create_shuffle_reader_op
=
block
.
append_op
(
type
=
"create_shuffle_reader"
,
...
...
@@ -44,7 +45,6 @@ create_shuffle_reader_op = block.append_op(
batch_reader
=
block
.
create_var
(
type
=
fluid
.
core
.
VarDesc
.
VarType
.
READER
,
name
=
"BatchReader"
)
batch_reader
.
desc
.
set_lod_levels
([
1
,
1
])
create_batch_reader_op
=
block
.
append_op
(
type
=
"create_batch_reader"
,
...
...
@@ -62,11 +62,9 @@ read_op = block.append_op(
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
[
res1
,
res2
]
=
exe
.
run
(
prog
,
fetch_list
=
[
out1
,
out2
]
,
return_numpy
=
False
)
[
res1
,
res2
]
=
exe
.
run
(
prog
,
fetch_list
=
[
out1
,
out2
])
test_pass
=
res1
.
lod
()
==
[
range
(
0
,
11
)]
and
res1
.
lod
()
==
[
range
(
0
,
11
)
]
and
np
.
array
(
res1
).
shape
==
(
10
,
2
)
and
np
.
array
(
res2
).
shape
==
(
10
,
1
)
test_pass
=
res1
.
shape
==
(
10
,
2
)
and
res2
.
shape
==
(
10
,
1
)
if
not
test_pass
:
exit
(
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录