Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1010e39b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1010e39b
编写于
2月 06, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ReadOp
上级
6e6f5c7e
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
193 addition
and
49 deletion
+193
-49
paddle/framework/framework.proto
paddle/framework/framework.proto
+2
-2
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+23
-6
paddle/framework/operator.cc
paddle/framework/operator.cc
+20
-6
paddle/framework/reader.cc
paddle/framework/reader.cc
+22
-18
paddle/framework/reader.h
paddle/framework/reader.h
+16
-16
paddle/framework/shape_inference.cc
paddle/framework/shape_inference.cc
+14
-0
paddle/framework/shape_inference.h
paddle/framework/shape_inference.h
+2
-1
paddle/operators/read_op.cc
paddle/operators/read_op.cc
+94
-0
未找到文件。
paddle/framework/framework.proto
浏览文件 @
1010e39b
...
@@ -116,7 +116,7 @@ message LoDTensorArrayDesc {
...
@@ -116,7 +116,7 @@ message LoDTensorArrayDesc {
optional
int32
lod_level
=
2
[
default
=
0
];
optional
int32
lod_level
=
2
[
default
=
0
];
}
}
message
Reader
{
repeated
LoDTensorDesc
lod_tensor
=
1
;
}
message
Reader
Desc
{
repeated
LoDTensorDesc
lod_tensor
=
1
;
}
message
VarDesc
{
message
VarDesc
{
enum
VarType
{
enum
VarType
{
...
@@ -136,7 +136,7 @@ message VarDesc {
...
@@ -136,7 +136,7 @@ message VarDesc {
optional
LoDTensorDesc
lod_tensor
=
4
;
optional
LoDTensorDesc
lod_tensor
=
4
;
optional
TensorDesc
selected_rows
=
5
;
optional
TensorDesc
selected_rows
=
5
;
optional
LoDTensorArrayDesc
tensor_array
=
6
;
optional
LoDTensorArrayDesc
tensor_array
=
6
;
optional
Reader
reader
=
7
;
optional
Reader
Desc
reader
=
7
;
}
}
message
BlockDesc
{
message
BlockDesc
{
...
...
paddle/framework/op_desc.cc
浏览文件 @
1010e39b
...
@@ -72,6 +72,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
...
@@ -72,6 +72,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
std
::
vector
<
DDim
>
GetRepeatedDim
(
const
std
::
string
&
name
)
const
override
;
const
OpDesc
&
op_
;
const
OpDesc
&
op_
;
const
BlockDesc
&
block_
;
const
BlockDesc
&
block_
;
};
};
...
@@ -457,22 +459,37 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
...
@@ -457,22 +459,37 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
DDim
CompileTimeInferShapeContext
::
GetDim
(
const
std
::
string
&
name
)
const
{
DDim
CompileTimeInferShapeContext
::
GetDim
(
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
DDim
res
;
try
{
try
{
auto
shape
=
var
->
GetShape
();
auto
shape
=
var
->
GetShape
();
if
(
shape
.
empty
())
{
res
=
shape
.
empty
()
?
make_ddim
({
0UL
})
:
make_ddim
(
shape
);
return
framework
::
make_ddim
({
0UL
});
}
else
{
return
framework
::
make_ddim
(
var
->
GetShape
());
}
}
catch
(...)
{
}
catch
(...)
{
VLOG
(
5
)
<<
"GetDim of variable "
<<
name
<<
" error"
;
VLOG
(
5
)
<<
"GetDim of variable "
<<
name
<<
" error"
;
std
::
rethrow_exception
(
std
::
current_exception
());
std
::
rethrow_exception
(
std
::
current_exception
());
}
}
return
res
;
}
std
::
vector
<
DDim
>
CompileTimeInferShapeContext
::
GetRepeatedDim
(
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
std
::
vector
<
DDim
>
res
;
try
{
auto
shapes
=
var
->
GetShapes
();
for
(
const
auto
&
s
:
shapes
)
{
res
.
push_back
(
s
.
empty
()
?
make_ddim
({
0UL
})
:
make_ddim
(
s
));
}
}
catch
(...)
{
VLOG
(
5
)
<<
"GetRepeatedDim of variable "
<<
name
<<
" error."
;
std
::
rethrow_exception
(
std
::
current_exception
());
}
return
res
;
}
}
void
CompileTimeInferShapeContext
::
SetDim
(
const
std
::
string
&
name
,
void
CompileTimeInferShapeContext
::
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
{
const
DDim
&
dim
)
{
block_
.
FindVarRecursive
(
name
)
->
SetShape
(
framework
::
vectorize
(
dim
));
block_
.
FindVarRecursive
(
name
)
->
SetShape
(
vectorize
(
dim
));
}
}
bool
CompileTimeInferShapeContext
::
IsRuntime
()
const
{
return
false
;
}
bool
CompileTimeInferShapeContext
::
IsRuntime
()
const
{
return
false
;
}
...
...
paddle/framework/operator.cc
浏览文件 @
1010e39b
...
@@ -320,8 +320,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -320,8 +320,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
if
(
length
==
0
)
{
if
(
length
==
0
)
{
return
false
;
return
false
;
}
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input %s should have more than one inputs"
,
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
name
);
"Input %s should not have more than one inputs"
,
name
);
auto
ipt
=
ins
[
0
];
auto
ipt
=
ins
[
0
];
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
return
var
!=
nullptr
;
return
var
!=
nullptr
;
...
@@ -333,8 +333,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -333,8 +333,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
if
(
length
==
0
)
{
if
(
length
==
0
)
{
return
false
;
return
false
;
}
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Output %s should have more than one inputs"
,
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
name
);
"Output %s should not have more than one inputs"
,
name
);
auto
ipt
=
outs
[
0
];
auto
ipt
=
outs
[
0
];
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
return
var
!=
nullptr
;
return
var
!=
nullptr
;
...
@@ -421,8 +421,22 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -421,8 +421,22 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
}
else
{
}
else
{
PADDLE_THROW
(
"Variable %s type_id %s, expect LoDTensor/SelectedRows."
,
PADDLE_THROW
(
name
,
var
->
Type
().
name
());
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
"type_id is %s."
,
name
,
var
->
Type
().
name
());
}
}
std
::
vector
<
DDim
>
GetRepeatedDim
(
const
std
::
string
&
name
)
const
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
ReaderHolder
>
())
{
return
var
->
Get
<
ReaderHolder
>
().
shapes
();
}
else
{
PADDLE_THROW
(
"Only ReaderHolder support 'GetRepeatedDim', but Variable %s's "
"type_id is %s."
,
name
,
var
->
Type
().
name
());
}
}
}
}
...
...
paddle/framework/reader.cc
浏览文件 @
1010e39b
...
@@ -25,13 +25,15 @@ DDim FileReader::shape(size_t idx) const {
...
@@ -25,13 +25,15 @@ DDim FileReader::shape(size_t idx) const {
return
shapes_
[
idx
];
return
shapes_
[
idx
];
}
}
std
::
vector
<
LoDTensor
>
ShuffleReader
::
ReadNext
(
)
{
void
ShuffleReader
::
ReadNext
(
std
::
vector
<
LoDtensor
>*
out
)
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
// Reload buffer with new data
// Reload buffer with new data
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reverse
(
buffer_size_
);
for
(
int
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
reader_
->
HasNext
())
{
if
(
reader_
->
HasNext
())
{
buffer_
.
push_back
(
reader_
->
ReadNext
());
buffer
.
push_back
(
std
::
vector
<
LoDTensor
>
());
reader_
->
ReadNext
(
&
buffer
.
back
());
}
else
{
}
else
{
break
;
break
;
}
}
...
@@ -39,29 +41,32 @@ std::vector<LoDTensor> ShuffleReader::ReadNext() {
...
@@ -39,29 +41,32 @@ std::vector<LoDTensor> ShuffleReader::ReadNext() {
std
::
random_shuffle
(
buffer_
.
begin
(),
buffer_
.
end
());
std
::
random_shuffle
(
buffer_
.
begin
(),
buffer_
.
end
());
iteration_pos_
=
0
;
iteration_pos_
=
0
;
}
}
if
(
buffer_
.
empty
())
{
out
->
clear
();
std
::
vector
<
LoDTensor
>
empty_res
;
if
(
!
buffer_
.
empty
())
{
return
empty_res
;
std
::
swap
(
*
out
,
buffer_
[
iteration_pos_
++
])
;
}
}
return
buffer_
[
iteration_pos_
++
];
// if buffer_ is empty, the 'out' will return as an empty vector.
}
}
std
::
vector
<
LoDTensor
>
BatchReader
::
ReadNext
(
)
{
void
BatchReader
::
ReadNext
(
std
::
vector
<
LoDtensor
>*
out
)
{
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reserve
(
batch_size_
);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
if
(
reader_
->
HasNext
())
{
if
(
reader_
->
HasNext
())
{
buffer_
.
push_back
(
reader_
->
ReadNext
());
buffer_
.
push_back
(
std
::
vector
<
LoDtensor
>
());
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
else
{
}
else
{
break
;
break
;
}
}
}
}
// Concat instances
// Concat instances
std
::
vector
<
LoDTensor
>
res
;
out
.
clear
()
;
if
(
buffer_
.
empty
())
{
if
(
buffer_
.
empty
())
{
return
res
;
// if buffer_ is empty, the 'out' will return as an empty vector.
return
;
}
}
int
out_num
=
buffer_
[
0
].
size
();
int
out_num
=
buffer_
[
0
].
size
();
res
.
reserve
(
out_num
);
out
->
reserve
(
out_num
);
for
(
int
j
=
0
;
j
<
out_num
;
++
j
)
{
for
(
int
j
=
0
;
j
<
out_num
;
++
j
)
{
// Merge shape and check date type
// Merge shape and check date type
std
::
type_index
batch_type
=
buffer_
[
0
][
j
].
type
();
std
::
type_index
batch_type
=
buffer_
[
0
][
j
].
type
();
...
@@ -76,9 +81,9 @@ std::vector<LoDTensor> BatchReader::ReadNext() {
...
@@ -76,9 +81,9 @@ std::vector<LoDTensor> BatchReader::ReadNext() {
batch_shape
[
0
]
+=
ins_shape
[
0
];
batch_shape
[
0
]
+=
ins_shape
[
0
];
}
}
LoDTensor
out
;
LoDTensor
out
_tensor
;
out
.
Resize
(
batch_shape
);
out
_tensor
.
Resize
(
batch_shape
);
out
.
mutable_data
(
platform
::
CPUPlace
(),
batch_type
);
out
_tensor
.
mutable_data
(
platform
::
CPUPlace
(),
batch_type
);
int64_t
dst_offset
=
0
;
int64_t
dst_offset
=
0
;
// Merge lod and data
// Merge lod and data
...
@@ -102,15 +107,14 @@ std::vector<LoDTensor> BatchReader::ReadNext() {
...
@@ -102,15 +107,14 @@ std::vector<LoDTensor> BatchReader::ReadNext() {
top_level_lod
.
back
()
+
top_level_lod
.
back
()
+
(
ins_lod
.
empty
()
?
ins_shape
[
0
]
:
(
ins_lod
[
0
].
size
()
-
1
)));
(
ins_lod
.
empty
()
?
ins_shape
[
0
]
:
(
ins_lod
[
0
].
size
()
-
1
)));
Tensor
dst
=
out
.
Slice
(
dst_offset
,
dst_offset
+
ins_shape
[
0
]);
Tensor
dst
=
out
_tensor
.
Slice
(
dst_offset
,
dst_offset
+
ins_shape
[
0
]);
Copy
(
buffer_
[
i
][
j
],
platform
::
CPUPlace
(),
&
dst
);
Copy
(
buffer_
[
i
][
j
],
platform
::
CPUPlace
(),
&
dst
);
dst_offset
+=
ins_shape
[
0
];
dst_offset
+=
ins_shape
[
0
];
}
}
batch_lod
.
insert
(
batch_lod
.
begin
(),
top_level_lod
);
batch_lod
.
insert
(
batch_lod
.
begin
(),
top_level_lod
);
out
.
set_lod
(
batch_lod
);
out
_tensor
.
set_lod
(
batch_lod
);
res
.
push_back
(
out
);
out
->
push_back
(
out_tensor
);
}
}
return
res
;
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/reader.h
浏览文件 @
1010e39b
...
@@ -15,14 +15,14 @@
...
@@ -15,14 +15,14 @@
#pragma once
#pragma once
#include "paddle/framework/ddim.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor
_array
.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
ReaderBase
{
class
ReaderBase
{
public:
public:
virtual
std
::
vector
<
LoDTensor
>
ReadNext
(
)
=
0
;
virtual
void
ReadNext
(
std
::
vector
<
LoDtensor
>*
out
)
=
0
;
virtual
bool
HasNext
()
const
=
0
;
virtual
bool
HasNext
()
const
=
0
;
virtual
DDim
shape
(
size_t
idx
)
const
=
0
;
virtual
DDim
shape
(
size_t
idx
)
const
=
0
;
...
@@ -73,24 +73,24 @@ class RandomReader : public FileReader {
...
@@ -73,24 +73,24 @@ class RandomReader : public FileReader {
dist_
=
std
::
uniform_real_distribution
<
float
>
(
min_
,
max_
);
dist_
=
std
::
uniform_real_distribution
<
float
>
(
min_
,
max_
);
}
}
std
::
vector
<
LoDTensor
>
ReadNext
(
)
override
{
void
ReadNext
(
std
::
vector
<
LoDtensor
>*
out
)
override
{
std
::
vector
<
LoDTensor
>
res
;
out
.
clear
()
;
res
.
reserve
(
shapes_
.
size
());
out
.
reserve
(
shapes_
.
size
());
for
(
const
DDim
&
shape
:
shapes_
)
{
for
(
const
DDim
&
shape
:
shapes_
)
{
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
shape
.
size
(),
2
,
shape
.
size
(),
2
,
"The rank of
in
put data should be 2 at least.(Now it's %d)"
,
"The rank of
reader's out
put data should be 2 at least.(Now it's %d)"
,
shape
.
size
());
shape
.
size
());
LoDTensor
out
;
LoDTensor
out
_tensor
;
out
.
Resize
(
shape
);
out
_tensor
.
Resize
(
shape
);
T
*
data
=
out
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
*
data
=
out
_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
int64_t
numel
=
product
(
shape
);
int64_t
numel
=
product
(
shape
);
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
data
[
i
]
=
dist_
(
engine_
);
data
[
i
]
=
dist_
(
engine_
);
}
}
res
.
push_back
(
out
);
out
.
push_back
(
out_tensor
);
}
}
return
res
;
return
out
;
}
}
bool
HasNext
()
const
override
{
return
true
;
}
bool
HasNext
()
const
override
{
return
true
;
}
...
@@ -111,11 +111,11 @@ class ShuffleReader : public DecoratedReader {
...
@@ -111,11 +111,11 @@ class ShuffleReader : public DecoratedReader {
buffer_
.
reserve
(
buffer_size
);
buffer_
.
reserve
(
buffer_size
);
}
}
std
::
vector
<
LoDTensor
>
ReadNext
(
)
override
;
void
ReadNext
(
std
::
vector
<
LoDtensor
>*
out
)
override
;
private:
private:
int
buffer_size_
;
int
buffer_size_
;
std
::
vector
<
std
::
vector
<
LoD
T
ensor
>>
buffer_
;
std
::
vector
<
std
::
vector
<
LoD
t
ensor
>>
buffer_
;
size_t
iteration_pos_
;
size_t
iteration_pos_
;
};
};
...
@@ -126,11 +126,11 @@ class BatchReader : public DecoratedReader {
...
@@ -126,11 +126,11 @@ class BatchReader : public DecoratedReader {
buffer_
.
reserve
(
batch_size_
);
buffer_
.
reserve
(
batch_size_
);
}
}
std
::
vector
<
LoDTensor
>
ReadNext
(
)
override
;
void
ReadNext
(
std
::
vector
<
LoDtensor
>*
out
)
override
;
private:
private:
int
batch_size_
;
int
batch_size_
;
std
::
vector
<
std
::
vector
<
LoD
T
ensor
>>
buffer_
;
std
::
vector
<
std
::
vector
<
LoD
t
ensor
>>
buffer_
;
};
};
// The ReaderHolder is used as readers' unified wrapper,
// The ReaderHolder is used as readers' unified wrapper,
...
@@ -141,7 +141,7 @@ class ReaderHolder {
...
@@ -141,7 +141,7 @@ class ReaderHolder {
ReaderBase
*
Get
()
const
{
return
reader_
.
get
();
}
ReaderBase
*
Get
()
const
{
return
reader_
.
get
();
}
std
::
vector
<
LoDTensor
>
ReadNext
()
{
return
reader_
->
ReadNext
(
);
}
void
ReadNext
(
std
::
vector
<
LoDtensor
>*
out
)
{
reader_
->
ReadNext
(
out
);
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
DDim
shape
(
size_t
idx
)
const
{
return
reader_
->
shape
(
idx
);
}
DDim
shape
(
size_t
idx
)
const
{
return
reader_
->
shape
(
idx
);
}
...
...
paddle/framework/shape_inference.cc
浏览文件 @
1010e39b
...
@@ -32,6 +32,16 @@ std::vector<DDim> InferShapeContext::GetInputsDim(
...
@@ -32,6 +32,16 @@ std::vector<DDim> InferShapeContext::GetInputsDim(
return
GetDims
(
arg_names
);
return
GetDims
(
arg_names
);
}
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetReaderDims
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Reader input '%s' should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
return
this
->
GetRepeatedDims
(
arg_names
[
0
]);
}
DDim
InferShapeContext
::
GetInputsElementDim
(
const
std
::
string
&
name
,
DDim
InferShapeContext
::
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
{
int
idx
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
=
Inputs
(
name
);
const
std
::
vector
<
std
::
string
>
&
names
=
Inputs
(
name
);
...
@@ -61,6 +71,7 @@ std::vector<DDim> InferShapeContext::GetDims(
...
@@ -61,6 +71,7 @@ std::vector<DDim> InferShapeContext::GetDims(
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetDim
(
name
);
});
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetDim
(
name
);
});
return
ret
;
return
ret
;
}
}
void
InferShapeContext
::
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
void
InferShapeContext
::
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
DDim
>
&
dims
)
{
const
std
::
vector
<
DDim
>
&
dims
)
{
size_t
length
=
names
.
size
();
size_t
length
=
names
.
size
();
...
@@ -72,14 +83,17 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
...
@@ -72,14 +83,17 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
SetDim
(
names
[
i
],
dims
[
i
]);
SetDim
(
names
[
i
],
dims
[
i
]);
}
}
}
}
std
::
vector
<
proto
::
VarDesc
::
VarType
>
InferShapeContext
::
GetInputsVarType
(
std
::
vector
<
proto
::
VarDesc
::
VarType
>
InferShapeContext
::
GetInputsVarType
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
return
GetVarTypes
(
Inputs
(
name
));
return
GetVarTypes
(
Inputs
(
name
));
}
}
std
::
vector
<
proto
::
VarDesc
::
VarType
>
InferShapeContext
::
GetOutputsVarType
(
std
::
vector
<
proto
::
VarDesc
::
VarType
>
InferShapeContext
::
GetOutputsVarType
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
return
GetVarTypes
(
Outputs
(
name
));
return
GetVarTypes
(
Outputs
(
name
));
}
}
std
::
vector
<
proto
::
VarDesc
::
VarType
>
InferShapeContext
::
GetVarTypes
(
std
::
vector
<
proto
::
VarDesc
::
VarType
>
InferShapeContext
::
GetVarTypes
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
proto
::
VarDesc
::
VarType
>
retv
;
std
::
vector
<
proto
::
VarDesc
::
VarType
>
retv
;
...
...
paddle/framework/shape_inference.h
浏览文件 @
1010e39b
...
@@ -36,8 +36,8 @@ class InferShapeContext {
...
@@ -36,8 +36,8 @@ class InferShapeContext {
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
;
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
DDim
>
GetReaderDims
(
const
std
::
string
&
name
)
const
DDim
;
DDim
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
;
DDim
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
;
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
);
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
);
...
@@ -61,6 +61,7 @@ class InferShapeContext {
...
@@ -61,6 +61,7 @@ class InferShapeContext {
protected:
protected:
virtual
DDim
GetDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
DDim
GetDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
virtual
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
std
::
vector
<
DDim
>
GetRepeatedDim
(
const
std
::
string
&
name
)
const
=
0
;
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
std
::
vector
<
proto
::
VarDesc
::
VarType
>
GetVarTypes
(
std
::
vector
<
proto
::
VarDesc
::
VarType
>
GetVarTypes
(
...
...
paddle/operators/read_op.cc
0 → 100644
浏览文件 @
1010e39b
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/framework/op_registry.h"
#include "paddle/framework/reader.h"
namespace
paddle
{
namespace
operators
{
class
ReadInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Reader"
),
"The ReadOp must take a reader as input."
);
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
"Out"
),
"The ReadOp should be assigned with output."
);
std
::
vector
<
DDim
>
reader_dims
=
ctx
->
GetReaderDims
(
"Reader"
);
std
::
vector
<
std
::
string
>
out_names
=
ctx
->
Outputs
(
"Out"
);
PADDLE_ENFORCE_EQ
(
reader_dims
.
size
(),
out_names
.
size
(),
"The reader's dim number doesn't match the output number."
);
ctx
->
SetOutputsDim
(
"Out"
,
reader_dims
);
}
};
class
ReadInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
string
reader_name
=
op_desc
.
Input
(
"Reader"
)[
0
];
std
::
vector
<
std
::
string
>
out_names
=
op_desc
.
Output
(
"Out"
);
framework
::
VarDesc
reader
=
block
.
FindVarRecursive
(
reader_name
);
auto
dtypes
=
reader
.
GetDataTypes
();
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
out_names
.
size
());
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
faremwork
::
VarDesc
&
out
=
block
->
FindRecursiveOrCreateVar
(
out_names
[
i
]);
out
.
SetType
(
framework
::
proto
::
DataType
::
LOD_TENSOR
);
out
.
SetDataType
(
dtypes
[
i
]);
}
}
};
class
ReadOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
framework
::
ReaderHolder
&
reader
=
scope
.
FindVar
(
Input
(
"Reader"
))
->
Get
<
ReaderHolder
>
();
if
(
!
reader
.
HasNext
())
{
// what shall we do???
return
;
}
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
.
ReadNext
(
&
ins
);
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
auto
*
out
=
scope
.
FindVar
(
out_arg_names
[
i
])
->
GetMutable
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
ins
[
i
].
dims
(),
out
->
dims
());
out
->
ShareDataWith
(
ins
[
i
]);
out
->
set_lod
(
ins
[
i
].
lod
());
}
}
};
class
ReadOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
ReadOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddInput
(
"Reader"
,
"(ReaderHolder) The executed reader."
);
AddOutput
(
"Out"
,
"(LoDTensor) The output data."
).
AsDuplicable
();
AddComment
(
R"DOC(
Read Operator
Execute a given reader once and output data.
)DOC"
)
}
};
}
// namespace operators
}
// namespace paddle
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录