Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
SummerGao.
Paddle
提交
75ae426a
P
Paddle
项目概览
SummerGao.
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
75ae426a
编写于
7月 02, 2018
作者:
Y
yuyang18
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'feature/change_op_kernel_to_func' into feature/fix_reshape_op_size
上级
1ce478f1
3b00ed81
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
15 addition
and
8 deletion
+15
-8
paddle/fluid/framework/op_registry.h
paddle/fluid/framework/op_registry.h
+12
-5
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+1
-1
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+2
-2
未找到文件。
paddle/fluid/framework/op_registry.h
浏览文件 @
75ae426a
...
...
@@ -76,8 +76,9 @@ class OpRegistry {
template
<
typename
PlaceType
,
bool
at_end
,
size_t
I
,
typename
...
KernelType
>
struct
OpKernelRegistrarFunctor
;
template
<
typename
PlaceType
,
typename
T
,
typename
KernelType
>
inline
void
RegisterKernelClass
(
const
char
*
op_type
,
const
char
*
library_type
)
{
template
<
typename
PlaceType
,
typename
T
,
typename
Func
>
inline
void
RegisterKernelClass
(
const
char
*
op_type
,
const
char
*
library_type
,
Func
func
)
{
std
::
string
library
(
library_type
);
std
::
string
data_layout
=
"ANYLAYOUT"
;
if
(
library
==
"MKLDNN"
)
{
...
...
@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type) {
OpKernelType
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
(),
StringToDataLayout
(
data_layout
),
StringToLibraryType
(
library_type
));
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
]
.
reset
(
new
KernelType
())
;
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
]
=
func
;
}
template
<
typename
PlaceType
,
size_t
I
,
typename
...
KernelTypes
>
...
...
@@ -96,7 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{
using
T
=
typename
KERNEL_TYPE
::
ELEMENT_TYPE
;
RegisterKernelClass
<
PlaceType
,
T
,
KERNEL_TYPE
>
(
op_type
,
library_type
);
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
[](
const
framework
::
ExecutionContext
&
ctx
)
{
KERNEL_TYPE
().
Compute
(
ctx
);
});
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelTypes
...
>>::
value
;
OpKernelRegistrarFunctor
<
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelTypes
...
>
func
;
...
...
@@ -150,7 +154,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
std
::
tuple
<
DataTypeAndKernelType
...
>>::
type
;
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{
RegisterKernelClass
<
PlaceType
,
T
,
KERNEL_TYPE
>
(
op_type
,
library_type
);
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
[](
const
framework
::
ExecutionContext
&
ctx
)
{
KERNEL_TYPE
().
Compute
(
ctx
);
});
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
DataTypeAndKernelType
...
>>::
value
;
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
75ae426a
...
...
@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx
=
pool
.
Get
(
expected_kernel_key
.
place_
);
}
kernel_iter
->
second
->
Compute
(
ExecutionContext
(
*
this
,
exec_scope
,
*
dev_ctx
));
kernel_iter
->
second
(
ExecutionContext
(
*
this
,
exec_scope
,
*
dev_ctx
));
if
(
!
transfered_inplace_vars
.
empty
())
{
// there is inplace variable has been transfered.
...
...
paddle/fluid/framework/operator.h
浏览文件 @
75ae426a
...
...
@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
class
OperatorWithKernel
:
public
OperatorBase
{
public:
using
OpKernelFunc
=
std
::
function
<
void
(
const
ExecutionContext
&
)
>
;
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelType
,
std
::
unique_ptr
<
OpKernelBase
>
,
OpKernelType
::
Hash
>
;
std
::
unordered_map
<
OpKernelType
,
OpKernelFunc
,
OpKernelType
::
Hash
>
;
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录