Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f9f910a3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f9f910a3
编写于
9月 25, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete op
上级
1cd20140
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
181 addition
and
5 deletion
+181
-5
paddle/pybind/protobuf.cc
paddle/pybind/protobuf.cc
+146
-4
python/paddle/v2/framework/tests/test_protobuf_descs.py
python/paddle/v2/framework/tests/test_protobuf_descs.py
+35
-1
未找到文件。
paddle/pybind/protobuf.cc
浏览文件 @
f9f910a3
...
...
@@ -14,8 +14,72 @@ limitations under the License. */
#include "paddle/pybind/protobuf.h"
#include <deque>
#include <iostream>
#include "paddle/framework/attribute.h"
// Cast boost::variant for PyBind.
// Copy from
// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
namespace
pybind11
{
namespace
detail
{
// Can be replaced by a generic lambda in C++14
struct
variant_caster_visitor
:
public
boost
::
static_visitor
<
handle
>
{
return_value_policy
policy
;
handle
parent
;
variant_caster_visitor
(
return_value_policy
policy
,
handle
parent
)
:
policy
(
policy
),
parent
(
parent
)
{}
template
<
class
T
>
handle
operator
()(
T
const
&
src
)
const
{
return
make_caster
<
T
>::
cast
(
src
,
policy
,
parent
);
}
};
template
<
class
Variant
>
struct
variant_caster
;
template
<
template
<
class
...
>
class
V
,
class
...
Ts
>
struct
variant_caster
<
V
<
Ts
...
>>
{
using
Type
=
V
<
Ts
...
>
;
template
<
class
T
>
bool
try_load
(
handle
src
,
bool
convert
)
{
auto
caster
=
make_caster
<
T
>
();
if
(
!
load_success_
&&
caster
.
load
(
src
,
convert
))
{
load_success_
=
true
;
value
=
cast_op
<
T
>
(
caster
);
return
true
;
}
return
false
;
}
bool
load
(
handle
src
,
bool
convert
)
{
auto
unused
=
{
false
,
try_load
<
Ts
>
(
src
,
convert
)...};
(
void
)(
unused
);
return
load_success_
;
}
static
handle
cast
(
Type
const
&
src
,
return_value_policy
policy
,
handle
parent
)
{
variant_caster_visitor
visitor
(
policy
,
parent
);
return
boost
::
apply_visitor
(
visitor
,
src
);
}
PYBIND11_TYPE_CASTER
(
Type
,
_
(
"Variant"
));
bool
load_success_
{
false
};
};
// Add specialization for concrete variant type
template
<
class
...
Args
>
struct
type_caster
<
boost
::
variant
<
Args
...
>>
:
variant_caster
<
boost
::
variant
<
Args
...
>>
{};
}
// namespace detail
}
// namespace pybind11
namespace
paddle
{
namespace
pybind
{
...
...
@@ -40,6 +104,15 @@ inline void VectorToRepeated(const std::vector<T> &vec,
}
}
template
<
typename
RepeatedField
>
inline
void
VectorToRepeated
(
const
std
::
vector
<
bool
>
&
vec
,
RepeatedField
*
repeated_field
)
{
repeated_field
->
Reserve
(
vec
.
size
());
for
(
auto
elem
:
vec
)
{
*
repeated_field
->
Add
()
=
elem
;
}
}
class
ProgramDescBind
;
class
OpDescBind
;
class
BlockDescBind
;
...
...
@@ -146,6 +219,10 @@ public:
void
operator
()(
const
std
::
vector
<
bool
>
&
v
)
const
{
VectorToRepeated
(
v
,
attr_
->
mutable_bools
());
}
void
operator
()(
BlockDesc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
idx
());
}
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
};
void
Sync
()
{
...
...
@@ -168,13 +245,52 @@ public:
for
(
auto
&
attr
:
attrs_
)
{
auto
*
attr_desc
=
op_desc_
.
add_attrs
();
attr_desc
->
set_name
(
attr
.
first
);
attr_desc
->
set_type
(
static_cast
<
AttrType
>
(
attr
.
second
.
which
()
-
1
));
attr_desc
->
set_type
(
static_cast
<
framework
::
AttrType
>
(
attr
.
second
.
which
()
-
1
));
boost
::
apply_visitor
(
SetAttrDescVisitor
(
attr_desc
),
attr
.
second
);
}
need_update_
=
false
;
}
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
{
return
attrs_
.
find
(
name
)
!=
attrs_
.
end
();
}
framework
::
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
{
auto
it
=
attrs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
return
static_cast
<
framework
::
AttrType
>
(
it
->
second
.
which
()
-
1
);
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
{
std
::
vector
<
std
::
string
>
retv
;
retv
.
reserve
(
attrs_
.
size
());
for
(
auto
&
attr
:
attrs_
)
{
retv
.
push_back
(
attr
.
first
);
}
return
retv
;
}
void
SetAttr
(
const
std
::
string
&
name
,
const
Attribute
&
v
)
{
this
->
attrs_
[
name
]
=
v
;
}
void
SetBlockAttr
(
const
std
::
string
&
name
,
BlockDescBind
&
block
);
int
GetBlockAttr
(
const
std
::
string
&
name
)
const
{
auto
it
=
attrs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
return
boost
::
get
<
BlockDesc
*>
(
it
->
second
)
->
idx
();
}
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
it
=
attrs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
return
it
->
second
;
}
private:
OpDesc
op_desc_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
inputs_
;
...
...
@@ -232,6 +348,8 @@ public:
}
}
BlockDesc
*
RawPtr
()
{
return
desc_
;
}
private:
ProgramDescBind
*
prog_
;
// not_own
BlockDesc
*
desc_
;
// not_own
...
...
@@ -303,6 +421,11 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return
prog_
->
Block
(
static_cast
<
size_t
>
(
this
->
desc_
->
parent_idx
()));
}
void
OpDescBind
::
SetBlockAttr
(
const
std
::
string
&
name
,
BlockDescBind
&
block
)
{
BlockDesc
*
desc
=
block
.
RawPtr
();
this
->
attrs_
[
name
]
=
desc
;
}
void
BindProgramDesc
(
py
::
module
&
m
)
{
py
::
class_
<
ProgramDescBind
>
(
m
,
"ProgramDesc"
,
""
)
.
def_static
(
"instance"
,
...
...
@@ -351,8 +474,19 @@ void BindVarDsec(py::module &m) {
}
void
BindOpDesc
(
py
::
module
&
m
)
{
py
::
class_
<
OpDescBind
>
(
m
,
"OpDesc"
,
""
)
.
def
(
"type"
,
&
OpDescBind
::
Type
)
py
::
enum_
<
framework
::
AttrType
>
(
m
,
"AttrType"
,
""
)
.
value
(
"INT"
,
AttrType
::
INT
)
.
value
(
"INTS"
,
AttrType
::
INTS
)
.
value
(
"FLOAT"
,
AttrType
::
FLOAT
)
.
value
(
"FLOATS"
,
AttrType
::
FLOATS
)
.
value
(
"STRING"
,
AttrType
::
STRING
)
.
value
(
"STRINGS"
,
AttrType
::
STRINGS
)
.
value
(
"BOOL"
,
AttrType
::
BOOLEAN
)
.
value
(
"BOOLS"
,
AttrType
::
BOOLEANS
)
.
value
(
"BLOCK"
,
AttrType
::
BLOCK
);
py
::
class_
<
OpDescBind
>
op_desc
(
m
,
"OpDesc"
,
""
);
op_desc
.
def
(
"type"
,
&
OpDescBind
::
Type
)
.
def
(
"set_type"
,
&
OpDescBind
::
SetType
)
.
def
(
"input"
,
&
OpDescBind
::
Input
)
.
def
(
"input_names"
,
&
OpDescBind
::
InputNames
)
...
...
@@ -361,7 +495,15 @@ void BindOpDesc(py::module &m) {
.
def
(
"output_names"
,
&
OpDescBind
::
OutputNames
)
.
def
(
"set_output"
,
&
OpDescBind
::
SetOutput
)
.
def
(
"__str__"
,
&
OpDescBind
::
DebugString
)
.
def
(
"__repr__"
,
&
OpDescBind
::
DebugString
);
.
def
(
"__repr__"
,
&
OpDescBind
::
DebugString
)
.
def
(
"has_attr"
,
&
OpDescBind
::
HasAttr
)
.
def
(
"attr_type"
,
&
OpDescBind
::
GetAttrType
)
.
def
(
"attr_names"
,
&
OpDescBind
::
AttrNames
)
.
def
(
"set_attr"
,
&
OpDescBind
::
SetAttr
)
.
def
(
"attr"
,
&
OpDescBind
::
GetAttr
)
.
def
(
"set_block_attr"
,
&
OpDescBind
::
SetBlockAttr
)
.
def
(
"get_block_attr"
,
&
OpDescBind
::
GetBlockAttr
);
}
}
// namespace pybind
}
// namespace paddle
python/paddle/v2/framework/tests/test_protobuf_descs.py
浏览文件 @
f9f910a3
...
...
@@ -20,6 +20,40 @@ class TestOpDesc(unittest.TestCase):
self
.
assertEqual
([
'z'
],
op
.
output
(
"Out"
))
self
.
assertEqual
([
"Out"
],
op
.
output_names
())
op
.
set_attr
(
"int_attr"
,
1
)
self
.
assertEqual
(
1
,
op
.
attr
(
"int_attr"
))
self
.
assertTrue
(
op
.
has_attr
(
"int_attr"
))
op
.
set_attr
(
"float_attr"
,
-
1.32
)
self
.
assertAlmostEqual
(
-
1.32
,
op
.
attr
(
"float_attr"
),
delta
=
1e-4
)
self
.
assertTrue
(
op
.
has_attr
(
"float_attr"
))
op
.
set_attr
(
"bool_attr"
,
False
)
self
.
assertFalse
(
op
.
attr
(
"bool_attr"
))
op
.
set_attr
(
"string_attr"
,
"abc"
)
self
.
assertEqual
(
"abc"
,
op
.
attr
(
"string_attr"
))
self
.
assertTrue
(
op
.
has_attr
(
"string_attr"
))
op
.
set_attr
(
"ints_attr"
,
[
1
,
2
,
3
])
self
.
assertEqual
([
1
,
2
,
3
],
op
.
attr
(
"ints_attr"
))
expected
=
[
1.2
,
2.3
,
3.4
]
op
.
set_attr
(
"floats_attr"
,
expected
)
for
e
,
a
in
zip
(
expected
,
op
.
attr
(
"floats_attr"
)):
self
.
assertAlmostEqual
(
e
,
a
,
delta
=
1e-4
)
op
.
set_attr
(
"strings_attr"
,
[
"a"
,
"b"
,
"c"
])
self
.
assertEqual
([
"a"
,
"b"
,
"c"
],
op
.
attr
(
"strings_attr"
))
op
.
set_attr
(
"bools_attr"
,
[
True
,
False
,
True
])
self
.
assertEqual
([
True
,
False
,
True
],
op
.
attr
(
"bools_attr"
))
self
.
assertEqual
(
8
,
len
(
op
.
attr_names
()))
op
.
set_block_attr
(
"block_attr"
,
prog
.
block
(
0
))
self
.
assertEqual
(
0
,
op
.
get_block_attr
(
"block_attr"
))
class
TestProgramDesc
(
unittest
.
TestCase
):
def
test_instance
(
self
):
...
...
@@ -51,7 +85,7 @@ class TestProgramDesc(unittest.TestCase):
class
TestVarDesc
(
unittest
.
TestCase
):
def
test_shape
(
self
):
program_desc
=
core
.
ProgramDesc
.
instance
()
block
=
program_desc
.
root_block
(
)
block
=
program_desc
.
block
(
0
)
var
=
block
.
new_var
(
'my_var'
)
src_shape
=
[
3
,
2
,
10
,
8
]
var
.
set_shape
(
src_shape
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录