Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d1e85e33
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d1e85e33
编写于
10月 25, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
shape type to int64_t, test=develop
上级
39b3bf24
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
119 addition
and
84 deletion
+119
-84
paddle/fluid/framework/attribute.h
paddle/fluid/framework/attribute.h
+115
-82
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+4
-2
未找到文件。
paddle/fluid/framework/attribute.h
浏览文件 @
d1e85e33
...
@@ -26,6 +26,113 @@ limitations under the License. */
...
@@ -26,6 +26,113 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
template
<
typename
T
>
struct
ExtractAttribute
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
T
*
operator
()(
Attribute
&
attr
)
const
{
T
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
T
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type %s, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
typeid
(
T
).
name
()),
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template
<
>
struct
ExtractAttribute
<
bool
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
bool
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
bool
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
float
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
bool
>
(
val
);
}
bool
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
bool
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type bool, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
>
struct
ExtractAttribute
<
int64_t
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
int64_t
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
int64_t
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
int64_t
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
>
struct
ExtractAttribute
<
std
::
vector
<
int64_t
>>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
std
::
vector
<
int64_t
>*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
std
::
vector
<
int
>
))
{
// NOLINT
std
::
vector
<
int
>
val
=
boost
::
get
<
std
::
vector
<
int
>>
(
attr
);
std
::
vector
<
int64_t
>
vec
(
val
.
begin
(),
val
.
end
());
attr
=
vec
;
}
else
if
(
attr
.
type
()
==
typeid
(
std
::
vector
<
float
>
))
{
// NOLINT
std
::
vector
<
float
>
val
=
boost
::
get
<
std
::
vector
<
float
>>
(
attr
);
std
::
vector
<
int64_t
>
vec
(
val
.
begin
(),
val
.
end
());
attr
=
vec
;
}
std
::
vector
<
int64_t
>*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
std
::
vector
<
int64_t
>>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
typename
T
>
template
<
typename
T
>
inline
proto
::
AttrType
AttrTypeID
()
{
inline
proto
::
AttrType
AttrTypeID
()
{
Attribute
tmp
=
T
();
Attribute
tmp
=
T
();
...
@@ -42,7 +149,11 @@ class AttrReader {
...
@@ -42,7 +149,11 @@ class AttrReader {
inline
const
T
&
Get
(
const
std
::
string
&
name
)
const
{
inline
const
T
&
Get
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
count
(
name
)
!=
0
,
"%s should be in AttributeMap"
,
PADDLE_ENFORCE
(
attrs_
.
count
(
name
)
!=
0
,
"%s should be in AttributeMap"
,
name
);
name
);
return
boost
::
get
<
T
>
(
attrs_
.
at
(
name
));
Attribute
&
attr
=
const_cast
<
Attribute
&>
(
attrs_
.
at
(
name
));
ExtractAttribute
<
T
>
extract_attr
(
name
);
T
*
attr_value
=
extract_attr
(
attr
);
return
*
attr_value
;
}
}
private:
private:
...
@@ -82,7 +193,7 @@ class DefaultValueSetter {
...
@@ -82,7 +193,7 @@ class DefaultValueSetter {
public:
public:
explicit
DefaultValueSetter
(
T
default_value
)
explicit
DefaultValueSetter
(
T
default_value
)
:
default_value_
(
default_value
)
{}
:
default_value_
(
default_value
)
{}
void
operator
()(
T
&
value
)
const
{
value
=
default_value_
;
}
void
operator
()(
T
&
value
)
const
{
value
=
default_value_
;
}
// NOLINT
private:
private:
T
default_value_
;
T
default_value_
;
...
@@ -117,84 +228,6 @@ class EnumInContainer {
...
@@ -117,84 +228,6 @@ class EnumInContainer {
std
::
unordered_set
<
T
>
container_
;
std
::
unordered_set
<
T
>
container_
;
};
};
template
<
typename
T
>
struct
ExtractAttribute
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
T
*
operator
()(
Attribute
&
attr
)
const
{
T
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
T
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type %s, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
typeid
(
T
).
name
()),
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template
<
>
struct
ExtractAttribute
<
bool
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
bool
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
bool
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
float
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
bool
>
(
val
);
}
bool
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
bool
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type bool, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
template
<
>
struct
ExtractAttribute
<
int64_t
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
int64_t
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
int64_t
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
int64_t
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
paddle
::
platform
::
demangle
(
attr
.
type
().
name
()));
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// check whether a certain attribute fit its limits
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
// an attribute can have more than one limits
template
<
typename
T
>
template
<
typename
T
>
...
@@ -235,7 +268,7 @@ class TypedAttrChecker {
...
@@ -235,7 +268,7 @@ class TypedAttrChecker {
return
*
this
;
return
*
this
;
}
}
void
operator
()(
AttributeMap
&
attr_map
)
const
{
void
operator
()(
AttributeMap
&
attr_map
)
const
{
// NOLINT
if
(
!
attr_map
.
count
(
attr_name_
))
{
if
(
!
attr_map
.
count
(
attr_name_
))
{
// user do not set this attr
// user do not set this attr
PADDLE_ENFORCE
(
!
default_value_setter_
.
empty
(),
PADDLE_ENFORCE
(
!
default_value_setter_
.
empty
(),
...
@@ -271,7 +304,7 @@ class OpAttrChecker {
...
@@ -271,7 +304,7 @@ class OpAttrChecker {
return
*
(
checker
.
target
<
TypedAttrChecker
<
T
>>
());
return
*
(
checker
.
target
<
TypedAttrChecker
<
T
>>
());
}
}
void
Check
(
AttributeMap
&
attr_map
)
const
{
void
Check
(
AttributeMap
&
attr_map
)
const
{
// NOLINT
for
(
const
auto
&
checker
:
attr_checkers_
)
{
for
(
const
auto
&
checker
:
attr_checkers_
)
{
checker
(
attr_map
);
checker
(
attr_map
);
}
}
...
...
paddle/fluid/framework/op_desc.cc
浏览文件 @
d1e85e33
...
@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
...
@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void
operator
()(
const
std
::
vector
<
BlockDesc
*>
&
v
)
const
{
void
operator
()(
const
std
::
vector
<
BlockDesc
*>
&
v
)
const
{
std
::
vector
<
int
>
blocks_idx
;
std
::
vector
<
int
>
blocks_idx
;
for
(
auto
blk
:
v
)
{
for
(
auto
blk
:
v
)
{
blocks_idx
.
push_back
(
blk
->
ID
());
blocks_idx
.
push_
s
back
(
blk
->
ID
());
}
}
VectorToRepeated
(
blocks_idx
,
attr_
->
mutable_blocks_idx
());
VectorToRepeated
(
blocks_idx
,
attr_
->
mutable_blocks_idx
());
}
}
void
operator
()(
BlockDesc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
BlockDesapply_visitorc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
int64_t
v
)
const
{
attr_
->
set_l
(
v
);
}
void
operator
()(
int64_t
v
)
const
{
attr_
->
set_l
(
v
);
}
void
operator
()(
const
std
::
vector
<
int64_t
>
&
v
)
const
{
void
operator
()(
const
std
::
vector
<
int64_t
>
&
v
)
const
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录