Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8198cad7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8198cad7
编写于
12月 14, 2021
作者:
Y
YuanRisheng
提交者:
GitHub
12月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove KernelName (#38082)
上级
4c1e27cc
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
12 addition
and
92 deletion
+12
-92
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+1
-1
paddle/fluid/imperative/prepared_operator.cc
paddle/fluid/imperative/prepared_operator.cc
+1
-1
paddle/pten/api/lib/kernel_dispatch.h
paddle/pten/api/lib/kernel_dispatch.h
+1
-1
paddle/pten/core/kernel_factory.cc
paddle/pten/core/kernel_factory.cc
+3
-3
paddle/pten/core/kernel_factory.h
paddle/pten/core/kernel_factory.h
+5
-73
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+1
-1
paddle/pten/tests/core/test_kernel_factory.cc
paddle/pten/tests/core/test_kernel_factory.cc
+0
-12
未找到文件。
paddle/fluid/framework/operator.cc
浏览文件 @
8198cad7
...
...
@@ -1275,7 +1275,7 @@ void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
kernel_type_
.
reset
(
new
OpKernelType
(
std
::
move
(
InnerGetExpectedKernelType
(
ctx
))));
auto
pt_kernel_name
=
pt
en
::
KernelName
(
pt_kernel_signature_
->
name
)
;
auto
pt_kernel_name
=
pt
_kernel_signature_
->
name
;
auto
pt_kernel_key
=
TransOpKernelTypeToPtenKernelKey
(
*
kernel_type_
.
get
());
pt_kernel_
.
reset
(
new
pten
::
Kernel
(
pten
::
KernelFactory
::
Instance
().
SelectKernel
(
...
...
paddle/fluid/imperative/prepared_operator.cc
浏览文件 @
8198cad7
...
...
@@ -165,7 +165,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto
pt_kernel_signature
=
op
.
GetExpectedPtenKernelArgs
(
dygraph_exe_ctx
);
VLOG
(
6
)
<<
framework
::
KernelSignatureToString
(
pt_kernel_signature
);
auto
pt_kernel_name
=
pt
en
::
KernelName
(
pt_kernel_signature
.
name
)
;
auto
pt_kernel_name
=
pt
_kernel_signature
.
name
;
auto
pt_kernel_key
=
TransOpKernelTypeToPtenKernelKey
(
expected_kernel_key
);
auto
pt_kernel
=
pten
::
KernelFactory
::
Instance
().
SelectKernel
(
pt_kernel_name
,
pt_kernel_key
);
...
...
paddle/pten/api/lib/kernel_dispatch.h
浏览文件 @
8198cad7
...
...
@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h"
// TODO(chenweihang): split Ke
rnelName, Ke
y, Kernel, Factory into diff files
// TODO(chenweihang): split Key, Kernel, Factory into diff files
#include "paddle/pten/core/kernel_factory.h"
// See Note [ Why still include the fluid headers? ]
...
...
paddle/pten/core/kernel_factory.cc
浏览文件 @
8198cad7
...
...
@@ -37,7 +37,7 @@ KernelFactory& KernelFactory::Instance() {
return
g_op_kernel_factory
;
}
Kernel
KernelFactory
::
SelectKernel
(
const
KernelName
&
kernel_name
,
Kernel
KernelFactory
::
SelectKernel
(
const
std
::
string
&
kernel_name
,
const
KernelKey
&
kernel_key
)
const
{
auto
iter
=
kernels_
.
find
(
kernel_name
);
if
(
iter
==
kernels_
.
end
())
{
...
...
@@ -51,7 +51,7 @@ Kernel KernelFactory::SelectKernel(const KernelName& kernel_name,
}
const
Kernel
&
KernelFactory
::
SelectKernelOrThrowError
(
const
KernelName
&
kernel_name
,
const
KernelKey
&
kernel_key
)
const
{
const
std
::
string
&
kernel_name
,
const
KernelKey
&
kernel_key
)
const
{
auto
iter
=
kernels_
.
find
(
kernel_name
);
PADDLE_ENFORCE_NE
(
iter
,
kernels_
.
end
(),
...
...
@@ -78,7 +78,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
}
const
Kernel
&
KernelFactory
::
SelectKernelOrThrowError
(
const
KernelName
&
kernel_name
,
const
std
::
string
&
kernel_name
,
Backend
backend
,
DataLayout
layout
,
DataType
dtype
)
const
{
...
...
paddle/pten/core/kernel_factory.h
浏览文件 @
8198cad7
...
...
@@ -51,61 +51,6 @@ class KernelContext;
using
KernelFn
=
void
(
*
)(
KernelContext
*
ctx
);
class
KernelName
final
{
public:
KernelName
(
std
::
string
name
,
std
::
string
overload_name
)
:
name_
(
std
::
move
(
name
)),
overload_name_
(
std
::
move
(
overload_name
))
{}
KernelName
(
const
std
::
string
&
kernel_name
)
{
ParseNameAndOverloadNameFromString
(
kernel_name
);
}
KernelName
(
const
char
*
kernel_name
)
{
std
::
string
kernel_name_str
(
kernel_name
);
ParseNameAndOverloadNameFromString
(
kernel_name_str
);
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
const
std
::
string
&
overload_name
()
const
{
return
overload_name_
;
}
struct
Hash
{
size_t
operator
()(
const
KernelName
&
kernel_name
)
const
{
return
std
::
hash
<
std
::
string
>
()(
kernel_name
.
name
())
^
(
std
::
hash
<
std
::
string
>
()(
kernel_name
.
overload_name
())
<<
1
);
}
};
size_t
hash_value
()
const
{
return
Hash
()(
*
this
);
}
bool
operator
<
(
const
KernelName
&
kernel_name
)
const
{
return
hash_value
()
<
kernel_name
.
hash_value
();
}
bool
operator
==
(
const
KernelName
&
kernel_name
)
const
{
return
hash_value
()
==
kernel_name
.
hash_value
();
}
bool
operator
!=
(
const
KernelName
&
kernel_name
)
const
{
return
hash_value
()
!=
kernel_name
.
hash_value
();
}
private:
void
ParseNameAndOverloadNameFromString
(
const
std
::
string
&
kernel_name
)
{
size_t
pos
=
kernel_name
.
find_first_of
(
'.'
);
if
(
pos
==
std
::
string
::
npos
)
{
name_
=
kernel_name
;
overload_name_
=
""
;
}
else
{
name_
=
kernel_name
.
substr
(
0
,
pos
);
overload_name_
=
kernel_name
.
substr
(
pos
+
1
,
kernel_name
.
size
());
}
}
// TODO(chenweihang): use string_view to improve performance later
std
::
string
name_
;
std
::
string
overload_name_
;
};
class
KernelKey
{
public:
KernelKey
()
=
default
;
...
...
@@ -265,9 +210,8 @@ class KernelFactory {
public:
// replaced by paddle::flat_hash_map later
using
KernelMap
=
paddle
::
flat_hash_map
<
KernelName
,
paddle
::
flat_hash_map
<
KernelKey
,
Kernel
,
KernelKey
::
Hash
>
,
KernelName
::
Hash
>
;
std
::
string
,
paddle
::
flat_hash_map
<
KernelKey
,
Kernel
,
KernelKey
::
Hash
>>
;
static
KernelFactory
&
Instance
();
...
...
@@ -277,15 +221,15 @@ class KernelFactory {
return
kernels_
.
find
(
TransToPtenKernelName
(
op_type
))
!=
kernels_
.
end
();
}
const
Kernel
&
SelectKernelOrThrowError
(
const
KernelName
&
kernel_name
,
const
Kernel
&
SelectKernelOrThrowError
(
const
std
::
string
&
kernel_name
,
const
KernelKey
&
kernel_key
)
const
;
const
Kernel
&
SelectKernelOrThrowError
(
const
KernelName
&
kernel_name
,
const
Kernel
&
SelectKernelOrThrowError
(
const
std
::
string
&
kernel_name
,
Backend
backend
,
DataLayout
layout
,
DataType
dtype
)
const
;
Kernel
SelectKernel
(
const
KernelName
&
kernel_name
,
Kernel
SelectKernel
(
const
std
::
string
&
kernel_name
,
const
KernelKey
&
kernel_key
)
const
;
private:
...
...
@@ -294,18 +238,6 @@ class KernelFactory {
KernelMap
kernels_
;
};
/** operator << overload **/
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelName
&
kernel_name
)
{
if
(
kernel_name
.
overload_name
().
empty
())
{
os
<<
kernel_name
.
name
();
}
else
{
os
<<
kernel_name
.
name
()
<<
"."
<<
kernel_name
.
overload_name
();
}
return
os
;
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelKey
&
kernel_key
)
{
os
<<
"("
<<
kernel_key
.
backend
()
<<
", "
<<
kernel_key
.
layout
()
<<
", "
<<
kernel_key
.
dtype
()
<<
")"
;
...
...
paddle/pten/core/kernel_registry.h
浏览文件 @
8198cad7
...
...
@@ -143,7 +143,7 @@ struct KernelRegistrar {
KernelArgsDefFn
args_def_fn
,
KernelFn
kernel_fn
,
void
*
variadic_kernel_fn
)
{
KernelName
kernel_name
(
kernel_name_cstr
);
std
::
string
kernel_name
(
kernel_name_cstr
);
KernelKey
kernel_key
(
backend
,
layout
,
dtype
);
Kernel
kernel
(
kernel_fn
,
variadic_kernel_fn
);
args_parse_fn
(
kernel_key
,
kernel
.
mutable_args_def
());
...
...
paddle/pten/tests/core/test_kernel_factory.cc
浏览文件 @
8198cad7
...
...
@@ -24,18 +24,6 @@ namespace tests {
// TODO(chenweihang): add more unittests later
TEST
(
KernelName
,
ConstructAndOStream
)
{
std
::
ostringstream
oss
;
oss
<<
pten
::
KernelName
(
"scale"
,
"host"
);
EXPECT_EQ
(
oss
.
str
(),
"scale.host"
);
pten
::
KernelName
kernel_name1
(
"scale.host"
);
EXPECT_EQ
(
kernel_name1
.
name
(),
"scale"
);
EXPECT_EQ
(
kernel_name1
.
overload_name
(),
"host"
);
pten
::
KernelName
kernel_name2
(
"scale.host"
);
EXPECT_EQ
(
kernel_name2
.
name
(),
"scale"
);
EXPECT_EQ
(
kernel_name2
.
overload_name
(),
"host"
);
}
TEST
(
KernelKey
,
ConstructAndOStream
)
{
pten
::
KernelKey
key
(
pten
::
Backend
::
CPU
,
pten
::
DataLayout
::
NCHW
,
pten
::
DataType
::
FLOAT32
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录