Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0bb9c80e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0bb9c80e
编写于
2月 06, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code and add unit tests
上级
1010e39b
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
244 addition
and
89 deletion
+244
-89
paddle/framework/executor.cc
paddle/framework/executor.cc
+5
-2
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+15
-2
paddle/framework/operator.cc
paddle/framework/operator.cc
+15
-2
paddle/framework/reader.cc
paddle/framework/reader.cc
+8
-8
paddle/framework/reader.h
paddle/framework/reader.h
+25
-26
paddle/framework/shape_inference.cc
paddle/framework/shape_inference.cc
+10
-0
paddle/framework/shape_inference.h
paddle/framework/shape_inference.h
+5
-2
paddle/framework/var_desc.cc
paddle/framework/var_desc.cc
+23
-12
paddle/framework/var_type.h
paddle/framework/var_type.h
+7
-1
paddle/operators/create_reader_op.cc
paddle/operators/create_reader_op.cc
+43
-18
paddle/operators/read_op.cc
paddle/operators/read_op.cc
+15
-13
paddle/pybind/protobuf.cc
paddle/pybind/protobuf.cc
+0
-2
python/paddle/v2/fluid/executor.py
python/paddle/v2/fluid/executor.py
+2
-1
python/paddle/v2/fluid/tests/test_cpp_reader.py
python/paddle/v2/fluid/tests/test_cpp_reader.py
+71
-0
未找到文件。
paddle/framework/executor.cc
浏览文件 @
0bb9c80e
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/reader.h"
#include "paddle/platform/place.h"
#include "paddle/platform/place.h"
#include "paddle/platform/profiler.h"
#include "paddle/platform/profiler.h"
...
@@ -52,11 +53,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
...
@@ -52,11 +53,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
var
->
GetMutable
<
LoDTensorArray
>
();
var
->
GetMutable
<
LoDTensorArray
>
();
}
else
if
(
var_type
==
proto
::
VarDesc
::
PLACE_LIST
)
{
}
else
if
(
var_type
==
proto
::
VarDesc
::
PLACE_LIST
)
{
var
->
GetMutable
<
platform
::
PlaceList
>
();
var
->
GetMutable
<
platform
::
PlaceList
>
();
}
else
if
(
var_type
==
proto
::
VarDesc
::
READER
)
{
var
->
GetMutable
<
ReaderHolder
>
();
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
"Variable type %d is not in "
"Variable type %d is not in "
"[L
oDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE,
"
"[L
OD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST,
"
"
PLACE_LIST
]"
,
"
LOD_RANK_TABLE, PLACE_LIST, READER
]"
,
var_type
);
var_type
);
}
}
}
}
...
...
paddle/framework/op_desc.cc
浏览文件 @
0bb9c80e
...
@@ -72,7 +72,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
...
@@ -72,7 +72,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
std
::
vector
<
DDim
>
GetRepeatedDim
(
const
std
::
string
&
name
)
const
override
;
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
;
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
override
;
const
OpDesc
&
op_
;
const
OpDesc
&
op_
;
const
BlockDesc
&
block_
;
const
BlockDesc
&
block_
;
...
@@ -470,7 +473,7 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
...
@@ -470,7 +473,7 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
return
res
;
return
res
;
}
}
std
::
vector
<
DDim
>
CompileTimeInferShapeContext
::
GetRepeatedDim
(
std
::
vector
<
DDim
>
CompileTimeInferShapeContext
::
GetRepeatedDim
s
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
...
@@ -491,6 +494,16 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
...
@@ -491,6 +494,16 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
const
DDim
&
dim
)
{
const
DDim
&
dim
)
{
block_
.
FindVarRecursive
(
name
)
->
SetShape
(
vectorize
(
dim
));
block_
.
FindVarRecursive
(
name
)
->
SetShape
(
vectorize
(
dim
));
}
}
void
CompileTimeInferShapeContext
::
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
std
::
vector
<
std
::
vector
<
int64_t
>>
dim_vec
(
dims
.
size
());
std
::
transform
(
dims
.
begin
(),
dims
.
end
(),
dim_vec
.
begin
(),
vectorize
);
var
->
SetShapes
(
dim_vec
);
}
bool
CompileTimeInferShapeContext
::
IsRuntime
()
const
{
return
false
;
}
bool
CompileTimeInferShapeContext
::
IsRuntime
()
const
{
return
false
;
}
proto
::
VarDesc
::
VarType
CompileTimeInferShapeContext
::
GetVarType
(
proto
::
VarDesc
::
VarType
CompileTimeInferShapeContext
::
GetVarType
(
...
...
paddle/framework/operator.cc
浏览文件 @
0bb9c80e
...
@@ -428,13 +428,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -428,13 +428,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
}
}
std
::
vector
<
DDim
>
GetRepeatedDim
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
DDim
>
GetRepeatedDim
s
(
const
std
::
string
&
name
)
const
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
ReaderHolder
>
())
{
if
(
var
->
IsType
<
ReaderHolder
>
())
{
return
var
->
Get
<
ReaderHolder
>
().
shapes
();
return
var
->
Get
<
ReaderHolder
>
().
shapes
();
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
"Only ReaderHolder support 'GetRepeatedDim', but Variable %s's "
"Only ReaderHolder support 'GetRepeatedDim
s
', but Variable %s's "
"type_id is %s."
,
"type_id is %s."
,
name
,
var
->
Type
().
name
());
name
,
var
->
Type
().
name
());
}
}
...
@@ -452,6 +452,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -452,6 +452,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
}
}
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
ReaderHolder
>
())
{
var
->
GetMutable
<
ReaderHolder
>
()
->
set_shapes
(
dims
);
}
else
{
PADDLE_THROW
(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s."
,
name
,
var
->
Type
().
name
());
}
}
proto
::
VarDesc
::
VarType
GetVarType
(
const
std
::
string
&
name
)
const
override
{
proto
::
VarDesc
::
VarType
GetVarType
(
const
std
::
string
&
name
)
const
override
{
auto
*
var
=
scope_
.
FindVar
(
name
);
auto
*
var
=
scope_
.
FindVar
(
name
);
return
ToVarType
(
var
->
Type
());
return
ToVarType
(
var
->
Type
());
...
...
paddle/framework/reader.cc
浏览文件 @
0bb9c80e
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
DDim
FileReader
::
shape
(
size_t
idx
)
const
{
DDim
ReaderBase
::
shape
(
size_t
idx
)
const
{
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
idx
,
shapes_
.
size
(),
idx
,
shapes_
.
size
(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements."
,
idx
,
"Cannot get the %d'th shape, 'shapes_' only has %d elements."
,
idx
,
...
@@ -25,15 +25,15 @@ DDim FileReader::shape(size_t idx) const {
...
@@ -25,15 +25,15 @@ DDim FileReader::shape(size_t idx) const {
return
shapes_
[
idx
];
return
shapes_
[
idx
];
}
}
void
ShuffleReader
::
ReadNext
(
std
::
vector
<
LoD
t
ensor
>*
out
)
{
void
ShuffleReader
::
ReadNext
(
std
::
vector
<
LoD
T
ensor
>*
out
)
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
// Reload buffer with new data
// Reload buffer with new data
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
re
vers
e
(
buffer_size_
);
buffer_
.
re
serv
e
(
buffer_size_
);
for
(
int
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
reader_
->
HasNext
())
{
if
(
reader_
->
HasNext
())
{
buffer
.
push_back
(
std
::
vector
<
LoDTensor
>
());
buffer
_
.
push_back
(
std
::
vector
<
LoDTensor
>
());
reader_
->
ReadNext
(
&
buffer
.
back
());
reader_
->
ReadNext
(
&
buffer
_
.
back
());
}
else
{
}
else
{
break
;
break
;
}
}
...
@@ -48,19 +48,19 @@ void ShuffleReader::ReadNext(std::vector<LoDtensor>* out) {
...
@@ -48,19 +48,19 @@ void ShuffleReader::ReadNext(std::vector<LoDtensor>* out) {
// if buffer_ is empty, the 'out' will return as an empty vector.
// if buffer_ is empty, the 'out' will return as an empty vector.
}
}
void
BatchReader
::
ReadNext
(
std
::
vector
<
LoD
t
ensor
>*
out
)
{
void
BatchReader
::
ReadNext
(
std
::
vector
<
LoD
T
ensor
>*
out
)
{
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reserve
(
batch_size_
);
buffer_
.
reserve
(
batch_size_
);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
if
(
reader_
->
HasNext
())
{
if
(
reader_
->
HasNext
())
{
buffer_
.
push_back
(
std
::
vector
<
LoD
t
ensor
>
());
buffer_
.
push_back
(
std
::
vector
<
LoD
T
ensor
>
());
reader_
->
ReadNext
(
&
buffer_
.
back
());
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
else
{
}
else
{
break
;
break
;
}
}
}
}
// Concat instances
// Concat instances
out
.
clear
();
out
->
clear
();
if
(
buffer_
.
empty
())
{
if
(
buffer_
.
empty
())
{
// if buffer_ is empty, the 'out' will return as an empty vector.
// if buffer_ is empty, the 'out' will return as an empty vector.
return
;
return
;
...
...
paddle/framework/reader.h
浏览文件 @
0bb9c80e
...
@@ -22,39 +22,36 @@ namespace framework {
...
@@ -22,39 +22,36 @@ namespace framework {
class
ReaderBase
{
class
ReaderBase
{
public:
public:
virtual
void
ReadNext
(
std
::
vector
<
LoDtensor
>*
out
)
=
0
;
explicit
ReaderBase
(
const
std
::
vector
<
DDim
>&
shapes
)
:
shapes_
(
shapes
)
{
PADDLE_ENFORCE
(
!
shapes_
.
empty
());
}
virtual
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
virtual
bool
HasNext
()
const
=
0
;
virtual
bool
HasNext
()
const
=
0
;
virtual
DDim
shape
(
size_t
idx
)
const
=
0
;
DDim
shape
(
size_t
idx
)
const
;
virtual
std
::
vector
<
DDim
>
shapes
()
const
=
0
;
std
::
vector
<
DDim
>
shapes
()
const
{
return
shapes_
;
}
void
set_shapes
(
const
std
::
vector
<
DDim
>&
shapes
)
{
shapes_
=
shapes
;
}
virtual
~
ReaderBase
()
{}
virtual
~
ReaderBase
()
{}
protected:
std
::
vector
<
DDim
>
shapes_
;
};
};
class
FileReader
:
public
ReaderBase
{
class
FileReader
:
public
ReaderBase
{
public:
public:
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
shapes
)
:
shapes_
(
shapes
)
{
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
shapes
)
:
ReaderBase
(
shapes
)
{}
PADDLE_ENFORCE
(
!
shapes_
.
empty
());
}
DDim
shape
(
size_t
idx
)
const
override
;
std
::
vector
<
DDim
>
shapes
()
const
override
{
return
shapes_
;
}
protected:
std
::
vector
<
DDim
>
shapes_
;
};
};
class
DecoratedReader
:
public
ReaderBase
{
class
DecoratedReader
:
public
ReaderBase
{
public:
public:
explicit
DecoratedReader
(
ReaderBase
*
reader
)
:
reader_
(
reader
)
{
explicit
DecoratedReader
(
ReaderBase
*
reader
)
:
ReaderBase
(
reader
->
shapes
()),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
}
bool
HasNext
()
const
override
{
return
reader_
->
HasNext
();
}
bool
HasNext
()
const
override
{
return
reader_
->
HasNext
();
}
DDim
shape
(
size_t
idx
)
const
override
{
return
reader_
->
shape
(
idx
);
}
std
::
vector
<
DDim
>
shapes
()
const
override
{
return
reader_
->
shapes
();
}
protected:
protected:
ReaderBase
*
reader_
;
ReaderBase
*
reader_
;
};
};
...
@@ -73,9 +70,9 @@ class RandomReader : public FileReader {
...
@@ -73,9 +70,9 @@ class RandomReader : public FileReader {
dist_
=
std
::
uniform_real_distribution
<
float
>
(
min_
,
max_
);
dist_
=
std
::
uniform_real_distribution
<
float
>
(
min_
,
max_
);
}
}
void
ReadNext
(
std
::
vector
<
LoD
t
ensor
>*
out
)
override
{
void
ReadNext
(
std
::
vector
<
LoD
T
ensor
>*
out
)
override
{
out
.
clear
();
out
->
clear
();
out
.
reserve
(
shapes_
.
size
());
out
->
reserve
(
shapes_
.
size
());
for
(
const
DDim
&
shape
:
shapes_
)
{
for
(
const
DDim
&
shape
:
shapes_
)
{
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
shape
.
size
(),
2
,
shape
.
size
(),
2
,
...
@@ -88,9 +85,8 @@ class RandomReader : public FileReader {
...
@@ -88,9 +85,8 @@ class RandomReader : public FileReader {
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
data
[
i
]
=
dist_
(
engine_
);
data
[
i
]
=
dist_
(
engine_
);
}
}
out
.
push_back
(
out_tensor
);
out
->
push_back
(
out_tensor
);
}
}
return
out
;
}
}
bool
HasNext
()
const
override
{
return
true
;
}
bool
HasNext
()
const
override
{
return
true
;
}
...
@@ -111,11 +107,11 @@ class ShuffleReader : public DecoratedReader {
...
@@ -111,11 +107,11 @@ class ShuffleReader : public DecoratedReader {
buffer_
.
reserve
(
buffer_size
);
buffer_
.
reserve
(
buffer_size
);
}
}
void
ReadNext
(
std
::
vector
<
LoD
t
ensor
>*
out
)
override
;
void
ReadNext
(
std
::
vector
<
LoD
T
ensor
>*
out
)
override
;
private:
private:
int
buffer_size_
;
int
buffer_size_
;
std
::
vector
<
std
::
vector
<
LoD
t
ensor
>>
buffer_
;
std
::
vector
<
std
::
vector
<
LoD
T
ensor
>>
buffer_
;
size_t
iteration_pos_
;
size_t
iteration_pos_
;
};
};
...
@@ -126,11 +122,11 @@ class BatchReader : public DecoratedReader {
...
@@ -126,11 +122,11 @@ class BatchReader : public DecoratedReader {
buffer_
.
reserve
(
batch_size_
);
buffer_
.
reserve
(
batch_size_
);
}
}
void
ReadNext
(
std
::
vector
<
LoD
t
ensor
>*
out
)
override
;
void
ReadNext
(
std
::
vector
<
LoD
T
ensor
>*
out
)
override
;
private:
private:
int
batch_size_
;
int
batch_size_
;
std
::
vector
<
std
::
vector
<
LoD
t
ensor
>>
buffer_
;
std
::
vector
<
std
::
vector
<
LoD
T
ensor
>>
buffer_
;
};
};
// The ReaderHolder is used as readers' unified wrapper,
// The ReaderHolder is used as readers' unified wrapper,
...
@@ -141,11 +137,14 @@ class ReaderHolder {
...
@@ -141,11 +137,14 @@ class ReaderHolder {
ReaderBase
*
Get
()
const
{
return
reader_
.
get
();
}
ReaderBase
*
Get
()
const
{
return
reader_
.
get
();
}
void
ReadNext
(
std
::
vector
<
LoD
t
ensor
>*
out
)
{
reader_
->
ReadNext
(
out
);
}
void
ReadNext
(
std
::
vector
<
LoD
T
ensor
>*
out
)
{
reader_
->
ReadNext
(
out
);
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
DDim
shape
(
size_t
idx
)
const
{
return
reader_
->
shape
(
idx
);
}
DDim
shape
(
size_t
idx
)
const
{
return
reader_
->
shape
(
idx
);
}
std
::
vector
<
DDim
>
shapes
()
const
{
return
reader_
->
shapes
();
}
std
::
vector
<
DDim
>
shapes
()
const
{
return
reader_
->
shapes
();
}
void
set_shapes
(
const
std
::
vector
<
DDim
>&
shapes
)
{
reader_
->
set_shapes
(
shapes
);
}
private:
private:
std
::
unique_ptr
<
ReaderBase
>
reader_
;
std
::
unique_ptr
<
ReaderBase
>
reader_
;
...
...
paddle/framework/shape_inference.cc
浏览文件 @
0bb9c80e
...
@@ -62,6 +62,16 @@ void InferShapeContext::SetOutputsDim(const std::string &name,
...
@@ -62,6 +62,16 @@ void InferShapeContext::SetOutputsDim(const std::string &name,
SetDims
(
names
,
dims
);
SetDims
(
names
,
dims
);
}
}
void
InferShapeContext
::
SetReaderDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Outputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Reader output '%s' should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
return
this
->
SetRepeatedDims
(
arg_names
[
0
],
dims
);
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetDims
(
std
::
vector
<
DDim
>
InferShapeContext
::
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
DDim
>
ret
;
std
::
vector
<
DDim
>
ret
;
...
...
paddle/framework/shape_inference.h
浏览文件 @
0bb9c80e
...
@@ -37,11 +37,12 @@ class InferShapeContext {
...
@@ -37,11 +37,12 @@ class InferShapeContext {
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
;
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
DDim
>
GetReaderDims
(
const
std
::
string
&
name
)
const
DDim
;
std
::
vector
<
DDim
>
GetReaderDims
(
const
std
::
string
&
name
)
const
;
DDim
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
;
DDim
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
;
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
);
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
);
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
void
SetReaderDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
virtual
AttrReader
Attrs
()
const
=
0
;
virtual
AttrReader
Attrs
()
const
=
0
;
virtual
const
std
::
vector
<
std
::
string
>
&
Inputs
(
virtual
const
std
::
vector
<
std
::
string
>
&
Inputs
(
...
@@ -61,7 +62,9 @@ class InferShapeContext {
...
@@ -61,7 +62,9 @@ class InferShapeContext {
protected:
protected:
virtual
DDim
GetDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
DDim
GetDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
virtual
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
std
::
vector
<
DDim
>
GetRepeatedDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
=
0
;
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
std
::
vector
<
proto
::
VarDesc
::
VarType
>
GetVarTypes
(
std
::
vector
<
proto
::
VarDesc
::
VarType
>
GetVarTypes
(
...
...
paddle/framework/var_desc.cc
浏览文件 @
0bb9c80e
...
@@ -57,10 +57,13 @@ size_t VarDesc::GetTensorDescNum() const {
...
@@ -57,10 +57,13 @@ size_t VarDesc::GetTensorDescNum() const {
void
VarDesc
::
SetShapes
(
void
VarDesc
::
SetShapes
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
multiple_dims
)
{
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
multiple_dims
)
{
PADDLE_ENFORCE_EQ
(
multiple_dims
.
size
(),
GetTensorDescNum
(),
if
(
multiple_dims
.
size
()
!=
GetTensorDescNum
())
{
"The number of given shapes(%d) doesn't equal to the "
VLOG
(
3
)
<<
"WARNING: The number of given shapes("
<<
multiple_dims
.
size
()
"number of sub tensor."
,
<<
") doesn't match the existing tensor number("
multiple_dims
.
size
(),
GetTensorDescNum
());
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_dims
.
size
());
}
std
::
vector
<
proto
::
TensorDesc
*>
tensors
=
mutable_tensor_descs
();
std
::
vector
<
proto
::
TensorDesc
*>
tensors
=
mutable_tensor_descs
();
for
(
size_t
i
=
0
;
i
<
multiple_dims
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
multiple_dims
.
size
();
++
i
)
{
VectorToRepeated
(
multiple_dims
[
i
],
tensors
[
i
]
->
mutable_dims
());
VectorToRepeated
(
multiple_dims
[
i
],
tensors
[
i
]
->
mutable_dims
());
...
@@ -87,10 +90,14 @@ void VarDesc::SetDataType(proto::DataType data_type) {
...
@@ -87,10 +90,14 @@ void VarDesc::SetDataType(proto::DataType data_type) {
void
VarDesc
::
SetDataTypes
(
void
VarDesc
::
SetDataTypes
(
const
std
::
vector
<
proto
::
DataType
>
&
multiple_data_type
)
{
const
std
::
vector
<
proto
::
DataType
>
&
multiple_data_type
)
{
PADDLE_ENFORCE_EQ
(
multiple_data_type
.
size
(),
GetTensorDescNum
(),
if
(
multiple_data_type
.
size
()
!=
GetTensorDescNum
())
{
"The number of given data types(%d) doesn't equal to the "
VLOG
(
3
)
<<
"WARNING: The number of given data types("
"number of sub tensor."
,
<<
multiple_data_type
.
size
()
multiple_data_type
.
size
(),
GetTensorDescNum
());
<<
") doesn't match the existing tensor number("
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_data_type
.
size
());
}
std
::
vector
<
proto
::
TensorDesc
*>
tensor_descs
=
mutable_tensor_descs
();
std
::
vector
<
proto
::
TensorDesc
*>
tensor_descs
=
mutable_tensor_descs
();
for
(
size_t
i
=
0
;
i
<
multiple_data_type
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
multiple_data_type
.
size
();
++
i
)
{
tensor_descs
[
i
]
->
set_data_type
(
multiple_data_type
[
i
]);
tensor_descs
[
i
]
->
set_data_type
(
multiple_data_type
[
i
]);
...
@@ -127,10 +134,14 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
...
@@ -127,10 +134,14 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
}
}
void
VarDesc
::
SetLoDLevels
(
const
std
::
vector
<
int32_t
>
&
multiple_lod_level
)
{
void
VarDesc
::
SetLoDLevels
(
const
std
::
vector
<
int32_t
>
&
multiple_lod_level
)
{
PADDLE_ENFORCE_EQ
(
multiple_lod_level
.
size
(),
GetTensorDescNum
(),
if
(
multiple_lod_level
.
size
()
!=
GetTensorDescNum
())
{
"The number of given data types(%d) doesn't equal to the "
VLOG
(
3
)
<<
"WARNING: The number of given lod_levels("
"number of sub tensor."
,
<<
multiple_lod_level
.
size
()
multiple_lod_level
.
size
(),
GetTensorDescNum
());
<<
") doesn't match the existing tensor number("
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_lod_level
.
size
());
}
switch
(
desc_
.
type
())
{
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
READER
:
{
case
proto
::
VarDesc
::
READER
:
{
size_t
i
=
0
;
size_t
i
=
0
;
...
...
paddle/framework/var_type.h
浏览文件 @
0bb9c80e
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/reader.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/variable.h"
#include "paddle/framework/variable.h"
...
@@ -31,6 +32,8 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
...
@@ -31,6 +32,8 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
return
proto
::
VarDesc_VarType_LOD_TENSOR_ARRAY
;
return
proto
::
VarDesc_VarType_LOD_TENSOR_ARRAY
;
}
else
if
(
type
.
hash_code
()
==
typeid
(
SelectedRows
).
hash_code
())
{
}
else
if
(
type
.
hash_code
()
==
typeid
(
SelectedRows
).
hash_code
())
{
return
proto
::
VarDesc_VarType_SELECTED_ROWS
;
return
proto
::
VarDesc_VarType_SELECTED_ROWS
;
}
else
if
(
type
.
hash_code
()
==
typeid
(
ReaderHolder
).
hash_code
())
{
return
proto
::
VarDesc_VarType_READER
;
}
else
{
}
else
{
PADDLE_THROW
(
"ToVarType:Unsupported type %s"
,
type
.
name
());
PADDLE_THROW
(
"ToVarType:Unsupported type %s"
,
type
.
name
());
}
}
...
@@ -40,7 +43,7 @@ template <typename Visitor>
...
@@ -40,7 +43,7 @@ template <typename Visitor>
inline
void
VisitVarType
(
const
framework
::
Variable
&
var
,
Visitor
visitor
)
{
inline
void
VisitVarType
(
const
framework
::
Variable
&
var
,
Visitor
visitor
)
{
switch
(
ToVarType
(
var
.
Type
()))
{
switch
(
ToVarType
(
var
.
Type
()))
{
case
proto
::
VarDesc_VarType_LOD_TENSOR
:
case
proto
::
VarDesc_VarType_LOD_TENSOR
:
visitor
(
var
.
Get
<
framework
::
LoDTensor
>
());
visitor
(
var
.
Get
<
LoDTensor
>
());
return
;
return
;
case
proto
::
VarDesc_VarType_LOD_RANK_TABLE
:
case
proto
::
VarDesc_VarType_LOD_RANK_TABLE
:
visitor
(
var
.
Get
<
LoDRankTable
>
());
visitor
(
var
.
Get
<
LoDRankTable
>
());
...
@@ -51,6 +54,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
...
@@ -51,6 +54,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case
proto
::
VarDesc_VarType_SELECTED_ROWS
:
case
proto
::
VarDesc_VarType_SELECTED_ROWS
:
visitor
(
var
.
Get
<
SelectedRows
>
());
visitor
(
var
.
Get
<
SelectedRows
>
());
return
;
return
;
case
proto
::
VarDesc_VarType_READER
:
visitor
(
var
.
Get
<
ReaderHolder
>
());
return
;
default:
default:
PADDLE_THROW
(
"Not supported visit type, %d"
,
ToVarType
(
var
.
Type
()));
PADDLE_THROW
(
"Not supported visit type, %d"
,
ToVarType
(
var
.
Type
()));
}
}
...
...
paddle/operators/create_reader_op.cc
浏览文件 @
0bb9c80e
...
@@ -18,12 +18,30 @@
...
@@ -18,12 +18,30 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
)
{
std
::
vector
<
framework
::
DDim
>
res
;
int
offset
=
0
;
for
(
int
len
:
ranks
)
{
auto
start_it
=
shape_concat
.
begin
()
+
offset
;
auto
end_it
=
start_it
+
len
;
res
.
push_back
(
framework
::
make_ddim
(
std
::
vector
<
int
>
(
start_it
,
end_it
)));
offset
+=
len
;
}
return
res
;
}
// general infershape for file readers
// general infershape for file readers
class
CreateFileReaderInferShape
:
public
framework
::
InferShapeBase
{
class
CreateFileReaderInferShape
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output file reader should not be null."
);
"The output file reader should not be null."
);
const
auto
shape_concat
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
ranks
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ranks"
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
ctx
->
SetReaderDims
(
"Out"
,
shapes
);
}
}
};
};
...
@@ -31,10 +49,22 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
...
@@ -31,10 +49,22 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
class
CreateDecoratedReaderInferShape
:
public
framework
::
InferShapeBase
{
class
CreateDecoratedReaderInferShape
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Underlying
_r
eader"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Underlying
R
eader"
),
"Input(Underlying
_r
eader) should not be null."
);
"Input(Underlying
R
eader) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output decorated reader should not be null."
);
"The output decorated reader should not be null."
);
ctx
->
SetReaderDims
(
"Out"
,
ctx
->
GetReaderDims
(
"UnderlyingReader"
));
}
};
// general var type inference for all readers
class
CreateReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
string
reader_name
=
op_desc
.
Output
(
"Out"
)[
0
];
framework
::
VarDesc
*
reader
=
block
->
FindVarRecursive
(
reader_name
);
reader
->
SetType
(
framework
::
proto
::
VarDesc
::
READER
);
}
}
};
};
...
@@ -51,15 +81,7 @@ class CreateRandomReaderOp : public framework::OperatorBase {
...
@@ -51,15 +81,7 @@ class CreateRandomReaderOp : public framework::OperatorBase {
int
(
shape_concat
.
size
()),
int
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
;
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
int
offset
=
0
;
for
(
int
len
:
ranks
)
{
auto
start_it
=
shape_concat
.
begin
()
+
offset
;
auto
end_it
=
start_it
+
len
;
shapes
.
push_back
(
framework
::
make_ddim
(
std
::
vector
<
int
>
(
start_it
,
end_it
)));
offset
+=
len
;
}
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
framework
::
RandomReader
<
T
>
(
shapes
,
Attr
<
float
>
(
"min"
),
out
->
Reset
(
new
framework
::
RandomReader
<
T
>
(
shapes
,
Attr
<
float
>
(
"min"
),
...
@@ -99,7 +121,7 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
...
@@ -99,7 +121,7 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"Underlying
_r
eader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"Underlying
R
eader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
...
@@ -113,7 +135,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -113,7 +135,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
CreateShuffleReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
CreateShuffleReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddInput
(
AddInput
(
"Underlying
_r
eader"
,
"Underlying
R
eader"
,
"(ReaderHolder) The underlying reader for creating a shuffle reader."
);
"(ReaderHolder) The underlying reader for creating a shuffle reader."
);
AddOutput
(
"Out"
,
"(ReaderHolder) The created shuffle reader."
);
AddOutput
(
"Out"
,
"(ReaderHolder) The created shuffle reader."
);
AddAttr
<
int
>
(
"buffer_size"
,
"The shuffle buffer size."
).
GreaterThan
(
0
);
AddAttr
<
int
>
(
"buffer_size"
,
"The shuffle buffer size."
).
GreaterThan
(
0
);
...
@@ -131,7 +153,7 @@ class CreateBatchReaderOp : public framework::OperatorBase {
...
@@ -131,7 +153,7 @@ class CreateBatchReaderOp : public framework::OperatorBase {
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"Underlying
_r
eader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"Underlying
R
eader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
...
@@ -145,7 +167,7 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -145,7 +167,7 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
CreateBatchReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
CreateBatchReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
op_proto
,
op_checker
)
{
AddInput
(
AddInput
(
"Underlying
_r
eader"
,
"Underlying
R
eader"
,
"(ReaderHolder) The underlying reader for creating a batch reader."
);
"(ReaderHolder) The underlying reader for creating a batch reader."
);
AddOutput
(
"Out"
,
"(ReaderHolder) The created batch reader."
);
AddOutput
(
"Out"
,
"(ReaderHolder) The created batch reader."
);
AddAttr
<
int
>
(
"batch_size"
,
AddAttr
<
int
>
(
"batch_size"
,
...
@@ -167,12 +189,15 @@ namespace ops = paddle::operators;
...
@@ -167,12 +189,15 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
create_random_reader
,
ops
::
CreateRandomReaderOp
<
float
>
,
REGISTER_OPERATOR
(
create_random_reader
,
ops
::
CreateRandomReaderOp
<
float
>
,
ops
::
CreateFileReaderInferShape
,
ops
::
CreateFileReaderInferShape
,
ops
::
CreateRandomReaderOpMaker
,
ops
::
CreateRandomReaderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateReaderInferVarType
);
REGISTER_OPERATOR
(
create_shuffle_reader
,
ops
::
CreateShuffleReaderOp
,
REGISTER_OPERATOR
(
create_shuffle_reader
,
ops
::
CreateShuffleReaderOp
,
ops
::
CreateDecoratedReaderInferShape
,
ops
::
CreateDecoratedReaderInferShape
,
ops
::
CreateShuffleReaderOpMaker
,
ops
::
CreateShuffleReaderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateReaderInferVarType
);
REGISTER_OPERATOR
(
create_batch_reader
,
ops
::
CreateBatchReaderOp
,
REGISTER_OPERATOR
(
create_batch_reader
,
ops
::
CreateBatchReaderOp
,
ops
::
CreateDecoratedReaderInferShape
,
ops
::
CreateDecoratedReaderInferShape
,
ops
::
CreateBatchReaderOpMaker
,
ops
::
CreateBatchReaderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
CreateReaderInferVarType
);
paddle/operators/read_op.cc
浏览文件 @
0bb9c80e
...
@@ -25,7 +25,7 @@ class ReadInferShape : public framework::InferShapeBase {
...
@@ -25,7 +25,7 @@ class ReadInferShape : public framework::InferShapeBase {
"The ReadOp must take a reader as input."
);
"The ReadOp must take a reader as input."
);
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
"Out"
),
"The ReadOp should be assigned with output."
);
"The ReadOp should be assigned with output."
);
std
::
vector
<
DDim
>
reader_dims
=
ctx
->
GetReaderDims
(
"Reader"
);
std
::
vector
<
framework
::
DDim
>
reader_dims
=
ctx
->
GetReaderDims
(
"Reader"
);
std
::
vector
<
std
::
string
>
out_names
=
ctx
->
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
out_names
=
ctx
->
Outputs
(
"Out"
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
reader_dims
.
size
(),
out_names
.
size
(),
reader_dims
.
size
(),
out_names
.
size
(),
...
@@ -40,12 +40,12 @@ class ReadInferVarType : public framework::VarTypeInference {
...
@@ -40,12 +40,12 @@ class ReadInferVarType : public framework::VarTypeInference {
framework
::
BlockDesc
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
std
::
string
reader_name
=
op_desc
.
Input
(
"Reader"
)[
0
];
std
::
string
reader_name
=
op_desc
.
Input
(
"Reader"
)[
0
];
std
::
vector
<
std
::
string
>
out_names
=
op_desc
.
Output
(
"Out"
);
std
::
vector
<
std
::
string
>
out_names
=
op_desc
.
Output
(
"Out"
);
framework
::
VarDesc
reader
=
block
.
FindVarRecursive
(
reader_name
);
framework
::
VarDesc
*
reader
=
block
->
FindVarRecursive
(
reader_name
);
auto
dtypes
=
reader
.
GetDataTypes
();
auto
dtypes
=
reader
->
GetDataTypes
();
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
out_names
.
size
());
PADDLE_ENFORCE_EQ
(
dtypes
.
size
(),
out_names
.
size
());
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
f
arem
work
::
VarDesc
&
out
=
block
->
FindRecursiveOrCreateVar
(
out_names
[
i
]);
f
rame
work
::
VarDesc
&
out
=
block
->
FindRecursiveOrCreateVar
(
out_names
[
i
]);
out
.
SetType
(
framework
::
proto
::
DataType
::
LOD_TENSOR
);
out
.
SetType
(
framework
::
proto
::
VarDesc
::
LOD_TENSOR
);
out
.
SetDataType
(
dtypes
[
i
]);
out
.
SetDataType
(
dtypes
[
i
]);
}
}
}
}
...
@@ -56,20 +56,18 @@ class ReadOp : public framework::OperatorBase {
...
@@ -56,20 +56,18 @@ class ReadOp : public framework::OperatorBase {
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
const
framework
::
ReaderHolder
&
reader
=
framework
::
ReaderHolder
*
reader
=
scope
.
FindVar
(
Input
(
"Reader"
))
->
Get
<
ReaderHolder
>
();
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
!
reader
.
HasNext
())
{
if
(
!
reader
->
HasNext
())
{
// what shall we do???
return
;
return
;
}
}
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
framework
::
LoDTensor
>
ins
;
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
.
ReadNext
(
&
ins
);
reader
->
ReadNext
(
&
ins
);
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
auto
*
out
=
auto
*
out
=
scope
.
FindVar
(
out_arg_names
[
i
])
->
GetMutable
<
framework
::
LoDTensor
>
();
scope
.
FindVar
(
out_arg_names
[
i
])
->
GetMutable
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
ins
[
i
].
dims
(),
out
->
dims
());
out
->
ShareDataWith
(
ins
[
i
]);
out
->
ShareDataWith
(
ins
[
i
]);
out
->
set_lod
(
ins
[
i
].
lod
());
out
->
set_lod
(
ins
[
i
].
lod
());
}
}
...
@@ -86,9 +84,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -86,9 +84,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
Read Operator
Read Operator
Execute a given reader once and output data.
Execute a given reader once and output data.
)DOC"
)
)DOC"
)
;
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
read
,
ops
::
ReadOp
,
ops
::
ReadInferShape
,
ops
::
ReadOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
ReadInferVarType
);
paddle/pybind/protobuf.cc
浏览文件 @
0bb9c80e
...
@@ -217,8 +217,6 @@ void BindVarDsec(py::module &m) {
...
@@ -217,8 +217,6 @@ void BindVarDsec(py::module &m) {
.
def
(
"set_shapes"
,
&
VarDesc
::
SetShapes
)
.
def
(
"set_shapes"
,
&
VarDesc
::
SetShapes
)
.
def
(
"set_dtype"
,
&
VarDesc
::
SetDataType
)
.
def
(
"set_dtype"
,
&
VarDesc
::
SetDataType
)
.
def
(
"set_dtypes"
,
&
VarDesc
::
SetDataTypes
)
.
def
(
"set_dtypes"
,
&
VarDesc
::
SetDataTypes
)
.
def
(
"set_tensor_num"
,
&
VarDesc
::
SetTensorDescNum
)
.
def
(
"tensor_num"
,
&
VarDesc
::
GetTensorDescNum
)
.
def
(
"shape"
,
&
VarDesc
::
GetShape
,
py
::
return_value_policy
::
reference
)
.
def
(
"shape"
,
&
VarDesc
::
GetShape
,
py
::
return_value_policy
::
reference
)
.
def
(
"shapes"
,
&
VarDesc
::
GetShapes
,
py
::
return_value_policy
::
reference
)
.
def
(
"shapes"
,
&
VarDesc
::
GetShapes
,
py
::
return_value_policy
::
reference
)
.
def
(
"dtype"
,
&
VarDesc
::
GetDataType
,
py
::
return_value_policy
::
reference
)
.
def
(
"dtype"
,
&
VarDesc
::
GetDataType
,
py
::
return_value_policy
::
reference
)
...
...
python/paddle/v2/fluid/executor.py
浏览文件 @
0bb9c80e
...
@@ -51,7 +51,8 @@ def as_numpy(tensor):
...
@@ -51,7 +51,8 @@ def as_numpy(tensor):
if
len
(
lod
)
==
0
:
if
len
(
lod
)
==
0
:
ans
=
tensor_data
ans
=
tensor_data
else
:
else
:
raise
RuntimeError
(
"LoD Calculate lacks unit tests and buggy"
)
#raise RuntimeError("LoD Calculate lacks unit tests and buggy")
ans
=
tensor_data
# elif len(lod) == 1:
# elif len(lod) == 1:
# ans = []
# ans = []
# idx = 0
# idx = 0
...
...
python/paddle/v2/fluid/tests/test_cpp_reader.py
0 → 100644
浏览文件 @
0bb9c80e
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.v2
as
paddle
import
paddle.v2.fluid
as
fluid
import
numpy
as
np
prog
=
fluid
.
framework
.
Program
()
block
=
prog
.
current_block
()
random_reader
=
block
.
create_var
(
type
=
fluid
.
core
.
VarDesc
.
VarType
.
READER
,
name
=
"RandomReader"
)
random_reader
.
desc
.
set_lod_levels
([
0
,
0
])
create_random_reader_op
=
block
.
append_op
(
type
=
"create_random_reader"
,
outputs
=
{
"Out"
:
random_reader
},
attrs
=
{
"shape_concat"
:
[
1
,
2
,
1
,
1
],
"ranks"
:
[
2
,
2
],
"min"
:
0.0
,
"max"
:
1.0
})
batch_reader
=
block
.
create_var
(
type
=
fluid
.
core
.
VarDesc
.
VarType
.
READER
,
name
=
(
"BatchReader"
))
batch_reader
.
desc
.
set_lod_levels
([
0
,
0
])
create_batch_reader_op
=
block
.
append_op
(
type
=
"create_batch_reader"
,
inputs
=
{
"UnderlyingReader"
:
random_reader
},
outputs
=
{
"Out"
:
batch_reader
},
attrs
=
{
"batch_size"
:
10
})
out1
=
block
.
create_var
(
type
=
fluid
.
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
name
=
"Out1"
,
shape
=
[
10
,
2
],
dtype
=
"float32"
,
lod_level
=
1
)
out2
=
block
.
create_var
(
type
=
fluid
.
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
name
=
"Out2"
,
shape
=
[
10
,
1
],
dtype
=
"float32"
,
lod_level
=
1
)
read_op
=
block
.
append_op
(
type
=
"read"
,
inputs
=
{
"Reader"
:
batch_reader
},
outputs
=
{
"Out"
:
[
out1
,
out2
]})
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
[
res1
,
res2
]
=
exe
.
run
(
prog
,
fetch_list
=
[
out1
,
out2
])
if
len
(
res1
)
==
0
or
len
(
res2
)
==
0
:
exit
(
1
)
exit
(
0
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录