Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ca909408
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ca909408
编写于
4月 27, 2022
作者:
C
Chen Weihang
提交者:
GitHub
4月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
opt attr eaque perf (#42272)
上级
88d68c08
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
49 addition
and
80 deletion
+49
-80
paddle/fluid/framework/attribute.h
paddle/fluid/framework/attribute.h
+5
-0
paddle/fluid/framework/infershape_utils.cc
paddle/fluid/framework/infershape_utils.cc
+11
-21
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+19
-32
paddle/fluid/imperative/prepared_operator.h
paddle/fluid/imperative/prepared_operator.h
+14
-27
未找到文件。
paddle/fluid/framework/attribute.h
浏览文件 @
ca909408
...
...
@@ -203,12 +203,17 @@ struct ExtractAttribute<std::vector<double>> {
const
std
::
string
&
attr_name_
;
};
template
<
typename
T
>
inline
proto
::
AttrType
AttrTypeID
()
{
Attribute
tmp
=
T
();
return
static_cast
<
proto
::
AttrType
>
(
tmp
.
which
()
-
1
);
}
inline
proto
::
AttrType
AttrTypeID
(
const
Attribute
&
attr
)
{
return
static_cast
<
proto
::
AttrType
>
(
attr
.
which
()
-
1
);
}
class
AttrReader
{
public:
explicit
AttrReader
(
const
AttributeMap
&
attrs
)
...
...
paddle/fluid/framework/infershape_utils.cc
浏览文件 @
ca909408
...
...
@@ -501,16 +501,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
}
else
if
(
ctx
->
HasAttr
(
attr_name
))
{
auto
&
attr
=
attr_reader
.
GetAttr
(
attr_name
);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int32_t
>
)))
{
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INTS
)
{
infer_meta_context
.
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
LONGS
)
{
infer_meta_context
.
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
int
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INT
)
{
infer_meta_context
.
EmplaceBackAttr
(
phi
::
IntArray
({
BOOST_GET_CONST
(
int
,
attr
)}));
}
else
{
...
...
@@ -524,15 +521,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
if
(
ctx
->
HasAttr
(
attr_name
))
{
// TODO(chentianyu03): support other attrs later
auto
&
attr
=
attr_reader
.
GetAttr
(
attr_name
);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
float
))
)
{
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
FLOAT
)
{
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
BOOST_GET_CONST
(
float
,
attr
)));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
string
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
STRING
)
{
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
BOOST_GET_CONST
(
std
::
string
,
attr
)));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
int
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INT
)
{
infer_meta_context
.
EmplaceBackAttr
(
phi
::
Scalar
(
BOOST_GET_CONST
(
int
,
attr
)));
}
else
{
...
...
@@ -562,8 +557,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
}
else
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
SCALARS
)
{
auto
&
attr
=
attr_reader
.
GetAttr
(
attr_name
);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int32_t
>
)))
{
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INTS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -571,8 +565,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list
.
emplace_back
(
val
);
}
infer_meta_context
.
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
LONGS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -580,8 +573,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list
.
emplace_back
(
val
);
}
infer_meta_context
.
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
float
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
FLOATS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
float
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -589,8 +581,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list
.
emplace_back
(
val
);
}
infer_meta_context
.
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
double
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
FLOAT64S
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
double
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -624,8 +615,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context
.
EmplaceBackAttr
(
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
INT64S
)
{
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int
>
)))
{
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INTS
)
{
// Emplace Back Attr according to the type of Phi_Kernel args.
const
auto
&
vector_int_attr
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
attr
);
const
std
::
vector
<
int64_t
>
vector_int64_attr
(
vector_int_attr
.
begin
(),
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
ca909408
...
...
@@ -2420,18 +2420,16 @@ void OperatorWithKernel::BuildPhiKernelContext(
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
INT_ARRAY
)
{
auto
attr_iter
=
Attrs
().
find
(
attr_names
[
i
]);
if
(
attr_iter
!=
Attrs
().
end
())
{
// shape is in the attribute
if
(
std
::
type_index
(
attr_iter
->
second
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr_iter
->
second
))));
}
else
if
(
std
::
type_index
(
attr_iter
->
second
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int32_t
>
)))
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr_iter
->
second
))));
}
else
if
(
std
::
type_index
(
attr_iter
->
second
.
type
())
==
std
::
type_index
(
typeid
(
int32_t
)))
{
auto
&
attr
=
attr_iter
->
second
;
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
LONGS
)
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
&
BOOST_GET_CONST
(
int32_t
,
attr_iter
->
second
),
1
)));
phi
::
IntArray
(
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
))));
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INTS
)
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr
))));
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INT
)
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
&
BOOST_GET_CONST
(
int32_t
,
attr
),
1
)));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported cast op attribute `%s` to IntArray when "
...
...
@@ -2449,21 +2447,16 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
}
}
else
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
SCALAR
)
{
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
auto
attr_iter
=
Attrs
().
find
(
attr_names
[
i
]);
if
(
attr_iter
!=
Attrs
().
end
())
{
// scalar is in the attribute
auto
&
attr
=
Attrs
().
at
(
attr_names
[
i
])
;
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
float
))
)
{
auto
&
attr
=
attr_iter
->
second
;
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
FLOAT
)
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
BOOST_GET_CONST
(
float
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
string
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
STRING
)
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
BOOST_GET_CONST
(
std
::
string
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
int
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INT
)
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
BOOST_GET_CONST
(
int
,
attr
))));
}
else
{
...
...
@@ -2480,8 +2473,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
else
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
SCALARS
)
{
auto
&
attr
=
Attrs
().
at
(
attr_names
[
i
]);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int32_t
>
)))
{
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
INTS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -2489,8 +2481,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list
.
emplace_back
(
val
);
}
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
LONGS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -2498,8 +2489,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list
.
emplace_back
(
val
);
}
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
float
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
FLOATS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
float
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -2507,8 +2497,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list
.
emplace_back
(
val
);
}
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
double
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
proto
::
AttrType
::
FLOAT64S
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
double
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -2559,12 +2548,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
BOOST_GET_CONST
(
int
,
attr_it
->
second
)));
pt_kernel_context
->
EmplaceBackAttr
(
data_type
);
}
else
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
INT64S
)
{
if
(
std
::
type_index
(
attr_it
->
second
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
if
(
AttrTypeID
(
attr_it
->
second
)
==
proto
::
AttrType
::
LONGS
)
{
pt_kernel_context
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr_it
->
second
));
}
else
if
(
std
::
type_index
(
attr_it
->
second
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int
>
)))
{
}
else
if
(
AttrTypeID
(
attr_it
->
second
)
==
proto
::
AttrType
::
INTS
)
{
// Emplace Back Attr according to the type of Phi_Kernel args.
const
auto
&
vector_int_attr
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
attr_it
->
second
);
...
...
paddle/fluid/imperative/prepared_operator.h
浏览文件 @
ca909408
...
...
@@ -382,20 +382,16 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
if
(
attrs
.
find
(
attr_names
[
i
])
!=
attrs
.
end
())
{
// shape is in the attribute
auto
&
attr
=
GetAttr
(
attrs
,
default_attrs
,
attr_names
[
i
]);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
LONGS
)
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int32_t
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
INTS
)
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
int64_t
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
LONG
)
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
&
BOOST_GET_CONST
(
int64_t
,
attr
),
1
)));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
int32_t
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
INT
)
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
&
BOOST_GET_CONST
(
int32_t
,
attr
),
1
)));
}
else
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
INT32S
)
{
...
...
@@ -429,15 +425,13 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
default_attrs
.
find
(
attr_names
[
i
])
!=
default_attrs
.
end
())
{
// scalar is in the attribute
auto
&
attr
=
GetAttr
(
attrs
,
default_attrs
,
attr_names
[
i
]);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
float
))
)
{
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
FLOAT
)
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
BOOST_GET_CONST
(
float
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
string
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
STRING
)
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
BOOST_GET_CONST
(
std
::
string
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
int
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
INT
)
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
BOOST_GET_CONST
(
int
,
attr
))));
}
else
{
...
...
@@ -465,8 +459,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
}
}
else
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
SCALARS
)
{
auto
&
attr
=
GetAttr
(
attrs
,
default_attrs
,
attr_names
[
i
]);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int32_t
>
)))
{
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
INTS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -474,8 +467,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
LONGS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -483,8 +475,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
float
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
FLOATS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
float
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -492,8 +483,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
double
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
FLOAT64S
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
double
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -501,8 +491,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
bool
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
BOOLEANS
)
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
bool
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
...
...
@@ -534,12 +523,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
BOOST_GET_CONST
(
int
,
attr
)));
kernel_ctx
->
EmplaceBackAttr
(
data_type
);
}
else
if
(
attr_defs
[
i
].
type_index
==
phi
::
AttributeType
::
INT64S
)
{
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
LONGS
)
{
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int
>
)))
{
}
else
if
(
AttrTypeID
(
attr
)
==
framework
::
proto
::
AttrType
::
INTS
)
{
// Emplace Back Attr according to the type of Phi_Kernel args.
const
auto
&
vector_int_attr
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
attr
);
const
std
::
vector
<
int64_t
>
vector_int64_attr
(
vector_int_attr
.
begin
(),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录