Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c2a05a90
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c2a05a90
编写于
4月 25, 2022
作者:
C
Chen Weihang
提交者:
GitHub
4月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
replace any by variant in infermeta (#42181)
上级
a3a6f0cf
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
66 addition
and
96 deletion
+66
-96
paddle/phi/core/infermeta_utils.cc
paddle/phi/core/infermeta_utils.cc
+33
-1
paddle/phi/core/infermeta_utils.h
paddle/phi/core/infermeta_utils.h
+33
-27
paddle/phi/core/type_defs.h
paddle/phi/core/type_defs.h
+0
-29
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+0
-8
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+0
-5
paddle/phi/tests/core/test_meta_fn_utils.cc
paddle/phi/tests/core/test_meta_fn_utils.cc
+0
-26
未找到文件。
paddle/phi/core/infermeta_utils.cc
浏览文件 @
c2a05a90
...
...
@@ -30,7 +30,7 @@ void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
outputs_
.
emplace_back
(
std
::
move
(
output
));
output_range_
.
emplace_back
(
std
::
pair
<
int
,
int
>
(
index
,
index
+
1
));
}
void
InferMetaContext
::
EmplaceBackAttr
(
paddle
::
any
attr
)
{
void
InferMetaContext
::
EmplaceBackAttr
(
Attribute
attr
)
{
attrs_
.
emplace_back
(
std
::
move
(
attr
));
}
...
...
@@ -120,6 +120,38 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
return
result
;
}
template
<
typename
AttrType
>
const
AttrType
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
{
try
{
return
paddle
::
get
<
AttrType
>
(
attrs_
.
at
(
idx
));
}
catch
(
paddle
::
bad_variant_access
const
&
e
)
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`."
,
std
::
type_index
(
typeid
(
AttrType
)).
name
()));
}
}
template
const
bool
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
int
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
int64_t
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
float
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
double
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
std
::
string
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
std
::
vector
<
bool
>
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
std
::
vector
<
int
>
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
std
::
vector
<
int64_t
>
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
std
::
vector
<
float
>
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
std
::
vector
<
double
>
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
std
::
vector
<
std
::
string
>
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
Scalar
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
std
::
vector
<
Scalar
>
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
IntArray
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
DataType
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
DataLayout
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
template
const
Place
&
InferMetaContext
::
AttrAt
(
size_t
idx
)
const
;
MetaFnFactory
&
MetaFnFactory
::
Instance
()
{
static
MetaFnFactory
g_meta_fn_map
;
return
g_meta_fn_map
;
...
...
paddle/phi/core/infermeta_utils.h
浏览文件 @
c2a05a90
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/meta_tensor.h"
...
...
@@ -41,7 +42,7 @@ class InferMetaContext {
void
EmplaceBackInput
(
MetaTensor
input
);
void
EmplaceBackOutput
(
MetaTensor
output
);
void
EmplaceBackAttr
(
paddle
::
any
attr
);
void
EmplaceBackAttr
(
Attribute
attr
);
void
EmplaceBackInputs
(
paddle
::
SmallVector
<
MetaTensor
,
phi
::
kInputSmallVectorSize
>
inputs
);
...
...
@@ -61,17 +62,7 @@ class InferMetaContext {
size_t
end
);
template
<
typename
AttrType
>
AttrType
AttrAt
(
size_t
idx
)
{
try
{
return
paddle
::
any_cast
<
AttrType
>
(
attrs_
.
at
(
idx
));
}
catch
(
paddle
::
bad_any_cast
&
e
)
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`, but actual attribute type is `%s`."
,
std
::
type_index
(
typeid
(
AttrType
)).
name
(),
std
::
type_index
(
attrs_
.
at
(
idx
).
type
()).
name
()));
}
}
const
AttrType
&
AttrAt
(
size_t
idx
)
const
;
const
std
::
pair
<
int
,
int
>&
InputRangeAt
(
size_t
idx
)
const
;
const
std
::
pair
<
int
,
int
>&
OutputRangeAt
(
size_t
idx
)
const
;
...
...
@@ -81,7 +72,7 @@ class InferMetaContext {
protected:
MetaConfig
config_
;
paddle
::
SmallVector
<
paddle
::
any
,
kAttrSmallVectorSize
>
attrs_
;
paddle
::
SmallVector
<
Attribute
,
kAttrSmallVectorSize
>
attrs_
;
paddle
::
SmallVector
<
std
::
pair
<
int
,
int
>
,
phi
::
kInputSmallVectorSize
>
input_range_
;
...
...
@@ -111,6 +102,21 @@ class InferMetaContext {
} \
}
#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
template <typename... Tail> \
struct InferMetaFnCallHelper<const attr_type&, Tail...> { \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \
static_assert(out_idx == 0, \
"InferMeta's Attributes should appear before Outputs."); \
const attr_type& arg = ctx->AttrAt<attr_type>(attr_idx); \
InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \
} \
}
template
<
typename
T
>
struct
InferMetaTypeTag
{};
...
...
@@ -201,27 +207,27 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};
// TODO(chenweihang): support other attr type later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
bool
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
int64_t
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
float
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
std
::
string
&
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
bool
>&
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
int
>&
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
int64_t
>&
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
float
>&
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
double
>&
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
std
::
string
>&
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataType
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
Backend
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
DataLayout
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
Scalar
&
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE
(
const
IntArray
&
);
// TODO(chenweihang): support vector<MetaTensor> input later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
std
::
string
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
Scalar
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
IntArray
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
std
::
vector
<
bool
>
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
std
::
vector
<
int
>
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
std
::
vector
<
int64_t
>
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
std
::
vector
<
float
>
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
std
::
vector
<
double
>
);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF
(
std
::
vector
<
std
::
string
>
);
template
<
typename
...
Tail
>
struct
InferMetaFnCallHelper
<
MetaTensor
*
,
Tail
...
>
{
...
...
paddle/phi/core/type_defs.h
浏览文件 @
c2a05a90
...
...
@@ -18,37 +18,8 @@
#include <string>
#include <vector>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/variant.h"
namespace
phi
{
class
Place
;
// NOTE: Add needed type in the future
using
Attribute
=
paddle
::
variant
<
bool
,
int
,
int64_t
,
float
,
double
,
std
::
string
,
std
::
vector
<
bool
>
,
std
::
vector
<
int
>
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
float
>
,
std
::
vector
<
double
>
,
std
::
vector
<
std
::
string
>
,
Scalar
,
std
::
vector
<
Scalar
>
,
IntArray
,
DataType
,
DataLayout
,
Place
>
;
class
Kernel
;
class
KernelKey
;
class
KernelArgsDef
;
...
...
paddle/phi/infermeta/unary.cc
浏览文件 @
c2a05a90
...
...
@@ -228,13 +228,6 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
out
->
set_dtype
(
x
.
dtype
());
}
void
CopyToInferMeta
(
const
MetaTensor
&
x
,
Backend
backend
,
bool
blocking
,
MetaTensor
*
out
)
{
UnchangedInferMeta
(
x
,
out
);
}
void
CreateLikeInferMeta
(
const
MetaTensor
&
x
,
DataType
dtype
,
MetaTensor
*
out
)
{
out
->
set_dims
(
x
.
dims
());
out
->
set_dtype
(
dtype
==
DataType
::
UNDEFINED
?
x
.
dtype
()
:
dtype
);
...
...
@@ -3008,6 +3001,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
}
// namespace phi
PD_REGISTER_INFER_META_FN
(
copy_to
,
phi
::
CopyToInferMeta
);
PD_REGISTER_INFER_META_FN
(
flatten
,
phi
::
FlattenInferMeta
);
PD_REGISTER_INFER_META_FN
(
split
,
phi
::
SplitInferMeta
);
paddle/phi/infermeta/unary.h
浏览文件 @
c2a05a90
...
...
@@ -58,11 +58,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void
CholeskyInferMeta
(
const
MetaTensor
&
x
,
bool
upper
,
MetaTensor
*
out
);
void
CopyToInferMeta
(
const
MetaTensor
&
x
,
Backend
backend
,
bool
blocking
,
MetaTensor
*
out
);
void
CreateLikeInferMeta
(
const
MetaTensor
&
x
,
DataType
dtype
,
MetaTensor
*
out
);
void
CumsumInferMeta
(
const
MetaTensor
&
x
,
...
...
paddle/phi/tests/core/test_meta_fn_utils.cc
浏览文件 @
c2a05a90
...
...
@@ -60,32 +60,6 @@ TEST(MetaFnFactory, InferMetaFnExists) {
EXPECT_EQ
(
dense_out1
.
dims
()[
1
],
dense_out2
.
dims
()[
1
]);
}
TEST
(
MetaFnFactory
,
CopyInferMetaFn
)
{
phi
::
DenseTensor
dense_x
;
dense_x
.
Resize
({
3
,
4
});
phi
::
MetaTensor
meta_x
(
&
dense_x
);
phi
::
DenseTensor
dense_out1
;
phi
::
MetaTensor
meta_out
(
&
dense_out1
);
phi
::
UnchangedInferMeta
(
meta_x
,
&
meta_out
);
auto
shared_meat_x
=
phi
::
MetaTensor
(
&
dense_x
);
phi
::
DenseTensor
dense_out2
;
auto
shared_meta_out
=
phi
::
MetaTensor
(
&
dense_out2
);
phi
::
InferMetaContext
ctx
;
ctx
.
EmplaceBackInput
(
shared_meat_x
);
ctx
.
EmplaceBackAttr
(
Backend
::
CPU
);
ctx
.
EmplaceBackAttr
(
false
);
ctx
.
EmplaceBackOutput
(
shared_meta_out
);
ctx
.
SetMetaConfig
({
/*is_runtime =*/
true
,
/*is_run_mkldnn_kernel=*/
false
});
phi
::
MetaFnFactory
::
Instance
().
Get
(
"copy_to"
)(
&
ctx
);
EXPECT_EQ
(
dense_out1
.
dims
().
size
(),
dense_out2
.
dims
().
size
());
EXPECT_EQ
(
dense_out1
.
dims
()[
0
],
dense_out2
.
dims
()[
0
]);
EXPECT_EQ
(
dense_out1
.
dims
()[
1
],
dense_out2
.
dims
()[
1
]);
}
TEST
(
MetaFnFactory
,
SplitInferMetaFn
)
{
phi
::
DenseTensor
dense_x
;
dense_x
.
Resize
({
4
,
10
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录