Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d1e85e33
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d1e85e33
编写于
10月 25, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
shape type to int64_t, test=develop
上级
39b3bf24
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
119 addition
and
84 deletion
+119
-84
paddle/fluid/framework/attribute.h
paddle/fluid/framework/attribute.h
+115
-82
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+4
-2
未找到文件。
paddle/fluid/framework/attribute.h
浏览文件 @
d1e85e33
...
@@ -26,6 +26,113 @@ limitations under the License. */
...
@@ -26,6 +26,113 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
template
<
typename
T
>
struct
ExtractAttribute
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
T
*
operator
()(
Attribute
&
attr
)
const
{
T
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
T
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type %s, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
typeid
(
T
).
name
()),
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template
<
>
struct
ExtractAttribute
<
bool
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
bool
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
bool
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
float
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
bool
>
(
val
);
}
bool
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
bool
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type bool, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
>
struct
ExtractAttribute
<
int64_t
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
int64_t
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
int64_t
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
int64_t
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
>
struct
ExtractAttribute
<
std
::
vector
<
int64_t
>>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
std
::
vector
<
int64_t
>*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
std
::
vector
<
int
>
))
{
// NOLINT
std
::
vector
<
int
>
val
=
boost
::
get
<
std
::
vector
<
int
>>
(
attr
);
std
::
vector
<
int64_t
>
vec
(
val
.
begin
(),
val
.
end
());
attr
=
vec
;
}
else
if
(
attr
.
type
()
==
typeid
(
std
::
vector
<
float
>
))
{
// NOLINT
std
::
vector
<
float
>
val
=
boost
::
get
<
std
::
vector
<
float
>>
(
attr
);
std
::
vector
<
int64_t
>
vec
(
val
.
begin
(),
val
.
end
());
attr
=
vec
;
}
std
::
vector
<
int64_t
>*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
std
::
vector
<
int64_t
>>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
typename
T
>
template
<
typename
T
>
inline
proto
::
AttrType
AttrTypeID
()
{
inline
proto
::
AttrType
AttrTypeID
()
{
Attribute
tmp
=
T
();
Attribute
tmp
=
T
();
...
@@ -42,7 +149,11 @@ class AttrReader {
...
@@ -42,7 +149,11 @@ class AttrReader {
inline
const
T
&
Get
(
const
std
::
string
&
name
)
const
{
inline
const
T
&
Get
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
count
(
name
)
!=
0
,
"%s should be in AttributeMap"
,
PADDLE_ENFORCE
(
attrs_
.
count
(
name
)
!=
0
,
"%s should be in AttributeMap"
,
name
);
name
);
return
boost
::
get
<
T
>
(
attrs_
.
at
(
name
));
Attribute
&
attr
=
const_cast
<
Attribute
&>
(
attrs_
.
at
(
name
));
ExtractAttribute
<
T
>
extract_attr
(
name
);
T
*
attr_value
=
extract_attr
(
attr
);
return
*
attr_value
;
}
}
private:
private:
...
@@ -82,7 +193,7 @@ class DefaultValueSetter {
...
@@ -82,7 +193,7 @@ class DefaultValueSetter {
public:
public:
explicit
DefaultValueSetter
(
T
default_value
)
explicit
DefaultValueSetter
(
T
default_value
)
:
default_value_
(
default_value
)
{}
:
default_value_
(
default_value
)
{}
void
operator
()(
T
&
value
)
const
{
value
=
default_value_
;
}
void
operator
()(
T
&
value
)
const
{
value
=
default_value_
;
}
// NOLINT
private:
private:
T
default_value_
;
T
default_value_
;
...
@@ -117,84 +228,6 @@ class EnumInContainer {
...
@@ -117,84 +228,6 @@ class EnumInContainer {
std
::
unordered_set
<
T
>
container_
;
std
::
unordered_set
<
T
>
container_
;
};
};
template
<
typename
T
>
struct
ExtractAttribute
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
T
*
operator
()(
Attribute
&
attr
)
const
{
T
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
T
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type %s, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
typeid
(
T
).
name
()),
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template
<
>
struct
ExtractAttribute
<
bool
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
bool
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
bool
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
float
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
bool
>
(
val
);
}
bool
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
bool
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type bool, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
>
struct
ExtractAttribute
<
int64_t
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
int64_t
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
int64_t
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
int64_t
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// check whether a certain attribute fit its limits
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
// an attribute can have more than one limits
template
<
typename
T
>
template
<
typename
T
>
...
@@ -235,7 +268,7 @@ class TypedAttrChecker {
...
@@ -235,7 +268,7 @@ class TypedAttrChecker {
return
*
this
;
return
*
this
;
}
}
void
operator
()(
AttributeMap
&
attr_map
)
const
{
void
operator
()(
AttributeMap
&
attr_map
)
const
{
// NOLINT
if
(
!
attr_map
.
count
(
attr_name_
))
{
if
(
!
attr_map
.
count
(
attr_name_
))
{
// user do not set this attr
// user do not set this attr
PADDLE_ENFORCE
(
!
default_value_setter_
.
empty
(),
PADDLE_ENFORCE
(
!
default_value_setter_
.
empty
(),
...
@@ -271,7 +304,7 @@ class OpAttrChecker {
...
@@ -271,7 +304,7 @@ class OpAttrChecker {
return
*
(
checker
.
target
<
TypedAttrChecker
<
T
>>
());
return
*
(
checker
.
target
<
TypedAttrChecker
<
T
>>
());
}
}
void
Check
(
AttributeMap
&
attr_map
)
const
{
void
Check
(
AttributeMap
&
attr_map
)
const
{
// NOLINT
for
(
const
auto
&
checker
:
attr_checkers_
)
{
for
(
const
auto
&
checker
:
attr_checkers_
)
{
checker
(
attr_map
);
checker
(
attr_map
);
}
}
...
...
paddle/fluid/framework/op_desc.cc
浏览文件 @
d1e85e33
...
@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
...
@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void
operator
()(
const
std
::
vector
<
BlockDesc
*>
&
v
)
const
{
void
operator
()(
const
std
::
vector
<
BlockDesc
*>
&
v
)
const
{
std
::
vector
<
int
>
blocks_idx
;
std
::
vector
<
int
>
blocks_idx
;
for
(
auto
blk
:
v
)
{
for
(
auto
blk
:
v
)
{
blocks_idx
.
push_back
(
blk
->
ID
());
blocks_idx
.
push_
s
back
(
blk
->
ID
());
}
}
VectorToRepeated
(
blocks_idx
,
attr_
->
mutable_blocks_idx
());
VectorToRepeated
(
blocks_idx
,
attr_
->
mutable_blocks_idx
());
}
}
void
operator
()(
BlockDesc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
BlockDesapply_visitorc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
int64_t
v
)
const
{
attr_
->
set_l
(
v
);
}
void
operator
()(
int64_t
v
)
const
{
attr_
->
set_l
(
v
);
}
void
operator
()(
const
std
::
vector
<
int64_t
>
&
v
)
const
{
void
operator
()(
const
std
::
vector
<
int64_t
>
&
v
)
const
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录