Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c8ba6d51
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c8ba6d51
编写于
2月 06, 2018
作者:
F
fengjiayi
提交者:
GitHub
2月 06, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #8135 from JiayiFeng/dev_make_VarDesc_supporting_multiple_tensor
Add type `Reader` for `VarDesc`
上级
445c74cd
e5227c2c
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
244 addition
and
24 deletion
+244
-24
paddle/framework/backward.cc
paddle/framework/backward.cc
+2
-2
paddle/framework/framework.proto
paddle/framework/framework.proto
+7
-3
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+2
-2
paddle/framework/program_desc_test.cc
paddle/framework/program_desc_test.cc
+2
-2
paddle/framework/var_desc.cc
paddle/framework/var_desc.cc
+163
-11
paddle/framework/var_desc.h
paddle/framework/var_desc.h
+19
-1
paddle/inference/io.cc
paddle/inference/io.cc
+1
-1
paddle/pybind/protobuf.cc
paddle/pybind/protobuf.cc
+12
-2
python/paddle/v2/fluid/tests/test_protobuf_descs.py
python/paddle/v2/fluid/tests/test_protobuf_descs.py
+36
-0
未找到文件。
paddle/framework/backward.cc
浏览文件 @
c8ba6d51
...
...
@@ -534,7 +534,7 @@ ParamGradInfoMap AppendBackward(
auto
root_block
=
program_desc
.
MutableBlock
(
root_block_idx
);
std
::
string
fill_one_op_out
=
GradVarName
(
target
.
Name
());
bool
is_scalar
=
target
.
Shape
()
==
std
::
vector
<
int64_t
>
{
1
};
bool
is_scalar
=
target
.
Get
Shape
()
==
std
::
vector
<
int64_t
>
{
1
};
PADDLE_ENFORCE
(
is_scalar
,
"target should be scalar"
);
VLOG
(
3
)
<<
"backward from loss="
<<
target
.
Name
()
<<
" data_type="
<<
target
.
GetDataType
();
...
...
@@ -565,7 +565,7 @@ ParamGradInfoMap AppendBackward(
auto
var
=
root_block
->
Var
(
fill_one_op_out
);
var
->
SetDataType
(
target
.
GetDataType
());
var
->
SetShape
(
target
.
Shape
());
var
->
SetShape
(
target
.
Get
Shape
());
auto
&
target_grad
=
retv
[
target
.
Name
()];
target_grad
.
name_
=
fill_one_op_out
;
target_grad
.
block_idx_
=
root_block_idx
;
...
...
paddle/framework/framework.proto
浏览文件 @
c8ba6d51
...
...
@@ -116,6 +116,8 @@ message LoDTensorArrayDesc {
optional
int32
lod_level
=
2
[
default
=
0
];
}
message
Reader
{
repeated
LoDTensorDesc
lod_tensor
=
1
;
}
message
VarDesc
{
enum
VarType
{
LOD_TENSOR
=
1
;
...
...
@@ -126,13 +128,15 @@ message VarDesc {
LOD_RANK_TABLE
=
6
;
LOD_TENSOR_ARRAY
=
7
;
PLACE_LIST
=
8
;
READER
=
9
;
}
required
string
name
=
1
;
required
VarType
type
=
2
;
optional
LoDTensorDesc
lod_tensor
=
3
;
optional
TensorDesc
selected_rows
=
4
;
optional
bool
persistable
=
3
[
default
=
false
];
optional
LoDTensorDesc
lod_tensor
=
4
;
optional
TensorDesc
selected_rows
=
5
;
optional
LoDTensorArrayDesc
tensor_array
=
6
;
optional
bool
persistable
=
5
[
default
=
false
]
;
optional
Reader
reader
=
7
;
}
message
BlockDesc
{
...
...
paddle/framework/op_desc.cc
浏览文件 @
c8ba6d51
...
...
@@ -458,11 +458,11 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
try
{
auto
shape
=
var
->
Shape
();
auto
shape
=
var
->
Get
Shape
();
if
(
shape
.
empty
())
{
return
framework
::
make_ddim
({
0UL
});
}
else
{
return
framework
::
make_ddim
(
var
->
Shape
());
return
framework
::
make_ddim
(
var
->
Get
Shape
());
}
}
catch
(...)
{
VLOG
(
5
)
<<
"GetDim of variable "
<<
name
<<
" error"
;
...
...
paddle/framework/program_desc_test.cc
浏览文件 @
c8ba6d51
...
...
@@ -53,7 +53,7 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_NE
(
copy
,
var_before
);
ASSERT_EQ
(
copy
->
Name
(),
var_before
->
Name
());
ASSERT_EQ
(
copy
->
GetType
(),
var_before
->
GetType
());
ASSERT_EQ
(
copy
->
Shape
(),
var_before
->
Shape
());
ASSERT_EQ
(
copy
->
GetShape
(),
var_before
->
Get
Shape
());
ASSERT_EQ
(
copy
->
Proto
()
->
SerializeAsString
(),
var_before
->
Proto
()
->
SerializeAsString
());
};
...
...
@@ -117,7 +117,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
ASSERT_NE
(
restored
,
var_before
);
ASSERT_EQ
(
restored
->
Name
(),
var_before
->
Name
());
ASSERT_EQ
(
restored
->
GetType
(),
var_before
->
GetType
());
ASSERT_EQ
(
restored
->
Shape
(),
var_before
->
Shape
());
ASSERT_EQ
(
restored
->
GetShape
(),
var_before
->
Get
Shape
());
ASSERT_EQ
(
restored
->
Proto
()
->
SerializeAsString
(),
var_before
->
Proto
()
->
SerializeAsString
());
};
...
...
paddle/framework/var_desc.cc
浏览文件 @
c8ba6d51
...
...
@@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated
(
dims
,
mutable_tensor_desc
()
->
mutable_dims
());
}
void
VarDesc
::
SetTensorDescNum
(
size_t
num
)
{
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
READER
:
{
auto
*
lod_tensors_ptr
=
desc_
.
mutable_reader
()
->
mutable_lod_tensor
();
lod_tensors_ptr
->
Clear
();
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
lod_tensors_ptr
->
Add
();
}
return
;
}
break
;
default:
PADDLE_THROW
(
"Setting 'sub_tensor_number' is not supported by the type of var %s."
,
this
->
Name
());
}
}
size_t
VarDesc
::
GetTensorDescNum
()
const
{
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
READER
:
return
desc_
.
reader
().
lod_tensor_size
();
break
;
default:
PADDLE_THROW
(
"Getting 'sub_tensor_number' is not supported by the type of var %s."
,
this
->
Name
());
}
}
void
VarDesc
::
SetShapes
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
multiple_dims
)
{
PADDLE_ENFORCE_EQ
(
multiple_dims
.
size
(),
GetTensorDescNum
(),
"The number of given shapes(%d) doesn't equal to the "
"number of sub tensor."
,
multiple_dims
.
size
(),
GetTensorDescNum
());
std
::
vector
<
proto
::
TensorDesc
*>
tensors
=
mutable_tensor_descs
();
for
(
size_t
i
=
0
;
i
<
multiple_dims
.
size
();
++
i
)
{
VectorToRepeated
(
multiple_dims
[
i
],
tensors
[
i
]
->
mutable_dims
());
}
}
std
::
vector
<
int64_t
>
VarDesc
::
GetShape
()
const
{
return
RepeatedToVector
(
tensor_desc
().
dims
());
}
std
::
vector
<
std
::
vector
<
int64_t
>>
VarDesc
::
GetShapes
()
const
{
std
::
vector
<
proto
::
TensorDesc
>
descs
=
tensor_descs
();
std
::
vector
<
std
::
vector
<
int64_t
>>
res
;
res
.
reserve
(
descs
.
size
());
for
(
const
auto
&
tensor_desc
:
descs
)
{
res
.
push_back
(
RepeatedToVector
(
tensor_desc
.
dims
()));
}
return
res
;
}
void
VarDesc
::
SetDataType
(
proto
::
DataType
data_type
)
{
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
}
std
::
vector
<
int64_t
>
VarDesc
::
Shape
()
const
{
return
RepeatedToVector
(
tensor_desc
().
dims
());
void
VarDesc
::
SetDataTypes
(
const
std
::
vector
<
proto
::
DataType
>
&
multiple_data_type
)
{
PADDLE_ENFORCE_EQ
(
multiple_data_type
.
size
(),
GetTensorDescNum
(),
"The number of given data types(%d) doesn't equal to the "
"number of sub tensor."
,
multiple_data_type
.
size
(),
GetTensorDescNum
());
std
::
vector
<
proto
::
TensorDesc
*>
tensor_descs
=
mutable_tensor_descs
();
for
(
size_t
i
=
0
;
i
<
multiple_data_type
.
size
();
++
i
)
{
tensor_descs
[
i
]
->
set_data_type
(
multiple_data_type
[
i
]);
}
}
proto
::
DataType
VarDesc
::
GetDataType
()
const
{
return
tensor_desc
().
data_type
();
}
std
::
vector
<
proto
::
DataType
>
VarDesc
::
GetDataTypes
()
const
{
std
::
vector
<
proto
::
TensorDesc
>
descs
=
tensor_descs
();
std
::
vector
<
proto
::
DataType
>
res
;
res
.
reserve
(
descs
.
size
());
for
(
const
auto
&
tensor_desc
:
descs
)
{
res
.
push_back
(
tensor_desc
.
data_type
());
}
return
res
;
}
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
LOD_TENSOR
:
...
...
@@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
desc_
.
mutable_tensor_array
()
->
set_lod_level
(
lod_level
);
break
;
default:
PADDLE_THROW
(
"Tensor type=%d does not support LoDLevel"
,
desc_
.
tensor_array
().
lod_level
());
PADDLE_THROW
(
"Setting 'lod_level' is not supported by the type of var %s."
,
this
->
Name
());
}
}
void
VarDesc
::
SetLoDLevels
(
const
std
::
vector
<
int32_t
>
&
multiple_lod_level
)
{
PADDLE_ENFORCE_EQ
(
multiple_lod_level
.
size
(),
GetTensorDescNum
(),
"The number of given data types(%d) doesn't equal to the "
"number of sub tensor."
,
multiple_lod_level
.
size
(),
GetTensorDescNum
());
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
READER
:
{
size_t
i
=
0
;
for
(
auto
&
lod_tensor
:
*
desc_
.
mutable_reader
()
->
mutable_lod_tensor
())
{
lod_tensor
.
set_lod_level
(
multiple_lod_level
[
i
++
]);
}
}
break
;
default:
PADDLE_THROW
(
"Setting 'lod_levels' is not supported by the type of var %s."
,
this
->
Name
());
}
}
...
...
@@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const {
case
proto
::
VarDesc
::
LOD_TENSOR_ARRAY
:
return
desc_
.
tensor_array
().
lod_level
();
default:
PADDLE_THROW
(
"Tensor type=%d does not support LoDLevel"
,
desc_
.
tensor_array
().
lod_level
());
PADDLE_THROW
(
"Getting 'lod_level' is not supported by the type of var %s."
,
this
->
Name
());
}
}
std
::
vector
<
int32_t
>
VarDesc
::
GetLoDLevels
()
const
{
std
::
vector
<
int32_t
>
res
;
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
READER
:
res
.
reserve
(
desc_
.
reader
().
lod_tensor_size
());
for
(
auto
&
lod_tensor
:
desc_
.
reader
().
lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
lod_level
());
}
return
res
;
break
;
default:
PADDLE_THROW
(
"Getting 'lod_levels' is not supported by the type of var %s."
,
this
->
Name
());
}
}
const
proto
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"
invoke TensorDesc must after set type
"
);
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"
The var's type hasn't been set.
"
);
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
SELECTED_ROWS
:
return
desc_
.
selected_rows
();
...
...
@@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
case
proto
::
VarDesc
::
LOD_TENSOR_ARRAY
:
return
desc_
.
tensor_array
().
tensor
();
default:
PADDLE_THROW
(
"The type of var %s is unsupported."
,
this
->
Name
());
PADDLE_THROW
(
"Getting 'tensor_desc' is not supported by the type of var %s."
,
this
->
Name
());
}
}
std
::
vector
<
proto
::
TensorDesc
>
VarDesc
::
tensor_descs
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
std
::
vector
<
proto
::
TensorDesc
>
res
;
res
.
reserve
(
GetTensorDescNum
());
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
READER
:
for
(
const
auto
&
lod_tensor
:
desc_
.
reader
().
lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
tensor
());
}
return
res
;
default:
PADDLE_THROW
(
"Getting 'tensor_descs' is not supported by the type of var "
"%s."
,
this
->
Name
());
}
}
proto
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"invoke MutableTensorDesc must after set type"
);
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
SELECTED_ROWS
:
return
desc_
.
mutable_selected_rows
();
...
...
@@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
case
proto
::
VarDesc
::
LOD_TENSOR_ARRAY
:
return
desc_
.
mutable_tensor_array
()
->
mutable_tensor
();
default:
PADDLE_THROW
(
"Unexpected branch."
);
PADDLE_THROW
(
"Getting 'mutable_tensor_desc' is not supported by the type of var "
"%s."
,
this
->
Name
());
}
}
std
::
vector
<
proto
::
TensorDesc
*>
VarDesc
::
mutable_tensor_descs
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
std
::
vector
<
proto
::
TensorDesc
*>
res
;
res
.
reserve
(
GetTensorDescNum
());
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
READER
:
for
(
auto
&
lod_tensor
:
*
desc_
.
mutable_reader
()
->
mutable_lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
mutable_tensor
());
}
return
res
;
default:
PADDLE_THROW
(
"Getting 'tensor_descs' is not supported by the type of var "
"%s."
,
this
->
Name
());
}
}
}
// namespace framework
}
// namespace paddle
paddle/framework/var_desc.h
浏览文件 @
c8ba6d51
...
...
@@ -68,18 +68,34 @@ class VarDesc {
void
SetName
(
std
::
string
name
)
{
desc_
.
set_name
(
name
);
}
void
SetTensorDescNum
(
size_t
num
);
size_t
GetTensorDescNum
()
const
;
void
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
);
void
SetShapes
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
multiple_dims
);
std
::
vector
<
int64_t
>
GetShape
()
const
;
std
::
vector
<
std
::
vector
<
int64_t
>>
GetShapes
()
const
;
void
SetDataType
(
proto
::
DataType
data_type
);
std
::
vector
<
int64_t
>
Shape
()
const
;
void
SetDataTypes
(
const
std
::
vector
<
proto
::
DataType
>
&
multiple_data_type
)
;
proto
::
DataType
GetDataType
()
const
;
std
::
vector
<
proto
::
DataType
>
GetDataTypes
()
const
;
void
SetLoDLevel
(
int32_t
lod_level
);
void
SetLoDLevels
(
const
std
::
vector
<
int32_t
>
&
multiple_lod_level
);
int32_t
GetLoDLevel
()
const
;
std
::
vector
<
int32_t
>
GetLoDLevels
()
const
;
proto
::
VarDesc
::
VarType
GetType
()
const
;
void
SetType
(
proto
::
VarDesc
::
VarType
type
);
...
...
@@ -90,7 +106,9 @@ class VarDesc {
private:
const
proto
::
TensorDesc
&
tensor_desc
()
const
;
std
::
vector
<
proto
::
TensorDesc
>
tensor_descs
()
const
;
proto
::
TensorDesc
*
mutable_tensor_desc
();
std
::
vector
<
proto
::
TensorDesc
*>
mutable_tensor_descs
();
proto
::
VarDesc
desc_
;
};
...
...
paddle/inference/io.cc
浏览文件 @
c8ba6d51
...
...
@@ -55,7 +55,7 @@ void LoadPersistables(framework::Executor& executor,
VLOG
(
3
)
<<
"parameter's name: "
<<
var
->
Name
();
framework
::
VarDesc
*
new_var
=
load_block
->
Var
(
var
->
Name
());
new_var
->
SetShape
(
var
->
Shape
());
new_var
->
SetShape
(
var
->
Get
Shape
());
new_var
->
SetDataType
(
var
->
GetDataType
());
new_var
->
SetType
(
var
->
GetType
());
new_var
->
SetLoDLevel
(
var
->
GetLoDLevel
());
...
...
paddle/pybind/protobuf.cc
浏览文件 @
c8ba6d51
...
...
@@ -214,11 +214,20 @@ void BindVarDsec(py::module &m) {
py
::
return_value_policy
::
reference
)
.
def
(
"set_name"
,
&
VarDesc
::
SetName
)
.
def
(
"set_shape"
,
&
VarDesc
::
SetShape
)
.
def
(
"set_shapes"
,
&
VarDesc
::
SetShapes
)
.
def
(
"set_dtype"
,
&
VarDesc
::
SetDataType
)
.
def
(
"shape"
,
&
VarDesc
::
Shape
,
py
::
return_value_policy
::
reference
)
.
def
(
"set_dtypes"
,
&
VarDesc
::
SetDataTypes
)
.
def
(
"set_tensor_num"
,
&
VarDesc
::
SetTensorDescNum
)
.
def
(
"tensor_num"
,
&
VarDesc
::
GetTensorDescNum
)
.
def
(
"shape"
,
&
VarDesc
::
GetShape
,
py
::
return_value_policy
::
reference
)
.
def
(
"shapes"
,
&
VarDesc
::
GetShapes
,
py
::
return_value_policy
::
reference
)
.
def
(
"dtype"
,
&
VarDesc
::
GetDataType
,
py
::
return_value_policy
::
reference
)
.
def
(
"dtypes"
,
&
VarDesc
::
GetDataTypes
,
py
::
return_value_policy
::
reference
)
.
def
(
"lod_level"
,
&
VarDesc
::
GetLoDLevel
)
.
def
(
"lod_levels"
,
&
VarDesc
::
GetLoDLevels
,
py
::
return_value_policy
::
reference
)
.
def
(
"set_lod_level"
,
&
VarDesc
::
SetLoDLevel
)
.
def
(
"set_lod_levels"
,
&
VarDesc
::
SetLoDLevels
)
.
def
(
"type"
,
&
VarDesc
::
GetType
)
.
def
(
"set_type"
,
&
VarDesc
::
SetType
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
VarDesc
>
)
...
...
@@ -233,7 +242,8 @@ void BindVarDsec(py::module &m) {
.
value
(
"STEP_SCOPES"
,
proto
::
VarDesc
::
STEP_SCOPES
)
.
value
(
"LOD_RANK_TABLE"
,
proto
::
VarDesc
::
LOD_RANK_TABLE
)
.
value
(
"LOD_TENSOR_ARRAY"
,
proto
::
VarDesc
::
LOD_TENSOR_ARRAY
)
.
value
(
"PLACE_LIST"
,
proto
::
VarDesc
::
PLACE_LIST
);
.
value
(
"PLACE_LIST"
,
proto
::
VarDesc
::
PLACE_LIST
)
.
value
(
"READER"
,
proto
::
VarDesc
::
READER
);
}
void
BindOpDesc
(
py
::
module
&
m
)
{
...
...
python/paddle/v2/fluid/tests/test_protobuf_descs.py
浏览文件 @
c8ba6d51
...
...
@@ -115,6 +115,18 @@ class TestVarDesc(unittest.TestCase):
self
.
assertEqual
(
src_shape
,
res_shape
)
self
.
assertEqual
(
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
,
var
.
type
())
def
test_multiple_shape
(
self
):
program_desc
=
core
.
ProgramDesc
()
block
=
program_desc
.
block
(
0
)
var
=
block
.
var
(
'my_reader'
)
var
.
set_type
(
core
.
VarDesc
.
VarType
.
READER
)
var
.
set_tensor_num
(
3
)
src_shapes
=
[[
2
,
3
,
3
],
[
4
,
5
],
[
6
,
7
,
8
,
9
]]
var
.
set_shapes
(
src_shapes
)
res_shapes
=
var
.
shapes
()
self
.
assertEqual
(
src_shapes
,
res_shapes
)
self
.
assertEqual
(
core
.
VarDesc
.
VarType
.
READER
,
var
.
type
())
def
test_dtype
(
self
):
program_desc
=
core
.
ProgramDesc
()
block
=
program_desc
.
block
(
0
)
...
...
@@ -124,6 +136,30 @@ class TestVarDesc(unittest.TestCase):
self
.
assertEqual
(
core
.
DataType
.
INT32
,
var
.
dtype
())
self
.
assertEqual
(
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
var
.
type
())
def
test_multiple_dtype
(
self
):
program_desc
=
core
.
ProgramDesc
()
block
=
program_desc
.
block
(
0
)
var
=
block
.
var
(
'my_reader'
)
var
.
set_type
(
core
.
VarDesc
.
VarType
.
READER
)
var
.
set_tensor_num
(
3
)
src_types
=
[
core
.
DataType
.
INT32
,
core
.
DataType
.
FP64
,
core
.
DataType
.
FP32
]
var
.
set_dtypes
(
src_types
)
self
.
assertEqual
(
src_types
,
var
.
dtypes
())
self
.
assertEqual
(
core
.
VarDesc
.
VarType
.
READER
,
var
.
type
())
def
test_multiple_lod_level
(
self
):
program_desc
=
core
.
ProgramDesc
()
block
=
program_desc
.
block
(
0
)
var
=
block
.
var
(
'my_reader'
)
var
.
set_type
(
core
.
VarDesc
.
VarType
.
READER
)
var
.
set_tensor_num
(
3
)
src_types
=
[
3
,
1
,
2
]
var
.
set_lod_levels
(
src_types
)
self
.
assertEqual
(
src_types
,
var
.
lod_levels
())
self
.
assertEqual
(
core
.
VarDesc
.
VarType
.
READER
,
var
.
type
())
class
TestBlockDesc
(
unittest
.
TestCase
):
def
test_add_var
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录