Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
73b4d1aa
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
73b4d1aa
编写于
12月 05, 2018
作者:
X
Xin Pan
提交者:
GitHub
12月 05, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14742 from panyx0718/infer2
support customized kernel selection
上级
21c0f874
82d68281
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
235 addition
and
84 deletion
+235
-84
cmake/operators.cmake
cmake/operators.cmake
+2
-0
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+3
-2
paddle/fluid/framework/op_kernel_type.cc
paddle/fluid/framework/op_kernel_type.cc
+54
-0
paddle/fluid/framework/op_kernel_type.h
paddle/fluid/framework/op_kernel_type.h
+30
-29
paddle/fluid/framework/op_registry.h
paddle/fluid/framework/op_registry.h
+85
-45
paddle/fluid/framework/operator_test.cc
paddle/fluid/framework/operator_test.cc
+42
-1
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+9
-5
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+8
-2
paddle/fluid/operators/conv_op.h
paddle/fluid/operators/conv_op.h
+2
-0
未找到文件。
cmake/operators.cmake
浏览文件 @
73b4d1aa
...
@@ -166,6 +166,8 @@ function(op_library TARGET)
...
@@ -166,6 +166,8 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator
# Append first implemented MKLDNN activation operator
if
(
${
MKLDNN_FILE
}
STREQUAL
"activation_mkldnn_op"
)
if
(
${
MKLDNN_FILE
}
STREQUAL
"activation_mkldnn_op"
)
file
(
APPEND
${
pybind_file
}
"USE_OP_DEVICE_KERNEL(relu, MKLDNN);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP_DEVICE_KERNEL(relu, MKLDNN);
\n
"
)
elseif
(
${
MKLDNN_FILE
}
STREQUAL
"conv_mkldnn_op"
)
file
(
APPEND
${
pybind_file
}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
\n
"
)
else
()
else
()
file
(
APPEND
${
pybind_file
}
"USE_OP_DEVICE_KERNEL(
${
TARGET
}
, MKLDNN);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP_DEVICE_KERNEL(
${
TARGET
}
, MKLDNN);
\n
"
)
endif
()
endif
()
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
73b4d1aa
...
@@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
...
@@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library
(
shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context
)
cc_library
(
shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context
)
cc_library
(
transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context
)
cc_library
(
transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context
)
cc_library
(
op_kernel_type SRCS op_kernel_type.cc DEPS device_context place
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache
)
shape_inference data_transform lod_tensor profiler transfer_scope_cache
op_kernel_type
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry device_context
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry device_context
)
...
@@ -191,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
...
@@ -191,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library
(
selected_rows SRCS selected_rows.cc DEPS tensor
)
cc_library
(
selected_rows SRCS selected_rows.cc DEPS tensor
)
cc_test
(
selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows
)
cc_test
(
selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows
)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto
)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto
op_kernel_type
)
cc_test
(
cow_ptr_tests SRCS details/cow_ptr_test.cc
)
cc_test
(
cow_ptr_tests SRCS details/cow_ptr_test.cc
)
cc_test
(
tuple_test SRCS tuple_test.cc
)
cc_test
(
tuple_test SRCS tuple_test.cc
)
...
...
paddle/fluid/framework/op_kernel_type.cc
0 → 100644
浏览文件 @
73b4d1aa
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_kernel_type.h"
namespace
paddle
{
namespace
framework
{
size_t
OpKernelType
::
Hash
::
operator
()(
const
OpKernelType
&
key
)
const
{
int
cur_loc
=
0
;
int
place
=
key
.
place_
.
which
();
cur_loc
+=
OpKernelType
::
kPlaceBits
;
int
data_type
=
static_cast
<
int
>
(
key
.
data_type_
)
<<
cur_loc
;
cur_loc
+=
OpKernelType
::
kPrimaryDTypeBits
;
int
data_layout
=
static_cast
<
int
>
(
key
.
data_layout_
)
<<
cur_loc
;
cur_loc
+=
OpKernelType
::
kLayoutBits
;
int
library_type
=
static_cast
<
int
>
(
key
.
library_type_
)
<<
cur_loc
;
cur_loc
+=
OpKernelType
::
kLibBits
;
int
customized_value
=
key
.
customized_type_value_
;
PADDLE_ENFORCE
(
customized_value
<
(
1
<<
OpKernelType
::
kCustomizeBits
));
customized_value
=
customized_value
<<
cur_loc
;
cur_loc
+=
OpKernelType
::
kCustomizeBits
;
PADDLE_ENFORCE
(
cur_loc
<
64
);
std
::
hash
<
int
>
hasher
;
return
hasher
(
place
+
data_type
+
data_layout
+
library_type
+
customized_value
);
}
bool
OpKernelType
::
operator
==
(
const
OpKernelType
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
data_type_
==
o
.
data_type_
&&
data_layout_
==
o
.
data_layout_
&&
library_type_
==
o
.
library_type_
&&
customized_type_value_
==
o
.
customized_type_value_
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/op_kernel_type.h
浏览文件 @
73b4d1aa
...
@@ -24,54 +24,55 @@ limitations under the License. */
...
@@ -24,54 +24,55 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
struct
OpKernelType
{
class
OpKernelType
{
struct
Hash
{
public:
size_t
operator
()(
const
OpKernelType
&
key
)
const
{
constexpr
static
int
kDefaultCustomizedTypeValue
=
0
;
int
place
=
key
.
place_
.
which
();
int
data_type
=
static_cast
<
int
>
(
key
.
data_type_
)
<<
LEFT_SHIFT
;
int
data_layout
=
static_cast
<
int
>
(
key
.
data_layout_
)
<<
(
LEFT_SHIFT
*
2
);
int
library_type
=
static_cast
<
int
>
(
key
.
library_type_
)
<<
(
LEFT_SHIFT
*
3
);
std
::
hash
<
int
>
hasher
;
return
hasher
(
place
+
data_type
+
data_layout
+
library_type
);
}
};
// place, data_type, library_type kinds less than 2^8
// In total should be smaller than 64.
constexpr
static
int
LEFT_SHIFT
=
8
;
constexpr
static
int
kPlaceBits
=
4
;
constexpr
static
int
kPrimaryDTypeBits
=
8
;
proto
::
VarType
::
Type
data_type_
;
constexpr
static
int
kLayoutBits
=
4
;
DataLayout
data_layout_
;
constexpr
static
int
kLibBits
=
4
;
platform
::
Place
place_
;
constexpr
static
int
kCustomizeBits
=
4
;
LibraryType
library_type_
;
OpKernelType
(
proto
::
VarType
::
Type
data_type
,
platform
::
Place
place
,
OpKernelType
(
proto
::
VarType
::
Type
data_type
,
platform
::
Place
place
,
DataLayout
data_layout
=
DataLayout
::
kAnyLayout
,
DataLayout
data_layout
=
DataLayout
::
kAnyLayout
,
LibraryType
library_type
=
LibraryType
::
kPlain
)
LibraryType
library_type
=
LibraryType
::
kPlain
,
int
customized_type_value
=
kDefaultCustomizedTypeValue
)
:
data_type_
(
data_type
),
:
data_type_
(
data_type
),
data_layout_
(
data_layout
),
data_layout_
(
data_layout
),
place_
(
place
),
place_
(
place
),
library_type_
(
library_type
)
{}
library_type_
(
library_type
),
customized_type_value_
(
customized_type_value
)
{}
OpKernelType
(
proto
::
VarType
::
Type
data_type
,
OpKernelType
(
proto
::
VarType
::
Type
data_type
,
const
platform
::
DeviceContext
&
dev_ctx
,
const
platform
::
DeviceContext
&
dev_ctx
,
DataLayout
data_layout
=
DataLayout
::
kAnyLayout
,
DataLayout
data_layout
=
DataLayout
::
kAnyLayout
,
LibraryType
library_type
=
LibraryType
::
kPlain
)
LibraryType
library_type
=
LibraryType
::
kPlain
,
int
customized_type_value
=
kDefaultCustomizedTypeValue
)
:
data_type_
(
data_type
),
:
data_type_
(
data_type
),
data_layout_
(
data_layout
),
data_layout_
(
data_layout
),
place_
(
dev_ctx
.
GetPlace
()),
place_
(
dev_ctx
.
GetPlace
()),
library_type_
(
library_type
)
{}
library_type_
(
library_type
),
customized_type_value_
(
customized_type_value
)
{}
virtual
~
OpKernelType
()
{}
struct
Hash
{
size_t
operator
()(
const
OpKernelType
&
key
)
const
;
};
size_t
hash_key
()
const
{
return
Hash
()(
*
this
);
}
size_t
hash_key
()
const
{
return
Hash
()(
*
this
);
}
bool
operator
==
(
const
OpKernelType
&
o
)
const
{
bool
operator
==
(
const
OpKernelType
&
o
)
const
;
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
data_type_
==
o
.
data_type_
&&
data_layout_
==
o
.
data_layout_
&&
library_type_
==
o
.
library_type_
;
}
bool
operator
!=
(
const
OpKernelType
&
o
)
const
{
return
!
(
*
this
==
o
);
}
bool
operator
!=
(
const
OpKernelType
&
o
)
const
{
return
!
(
*
this
==
o
);
}
proto
::
VarType
::
Type
data_type_
;
DataLayout
data_layout_
;
platform
::
Place
place_
;
LibraryType
library_type_
;
int
customized_type_value_
;
};
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
...
...
paddle/fluid/framework/op_registry.h
浏览文件 @
73b4d1aa
...
@@ -35,6 +35,7 @@ limitations under the License. */
...
@@ -35,6 +35,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
Registrar
{
class
Registrar
{
public:
public:
// In our design, various kinds of classes, e.g., operators and kernels,
// In our design, various kinds of classes, e.g., operators and kernels,
...
@@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor;
...
@@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor;
template
<
typename
PlaceType
,
typename
T
,
typename
Func
>
template
<
typename
PlaceType
,
typename
T
,
typename
Func
>
inline
void
RegisterKernelClass
(
const
char
*
op_type
,
const
char
*
library_type
,
inline
void
RegisterKernelClass
(
const
char
*
op_type
,
const
char
*
library_type
,
Func
func
)
{
int
customized_type_value
,
Func
func
)
{
std
::
string
library
(
library_type
);
std
::
string
library
(
library_type
);
std
::
string
data_layout
=
"ANYLAYOUT"
;
std
::
string
data_layout
=
"ANYLAYOUT"
;
if
(
library
==
"MKLDNN"
)
{
if
(
library
==
"MKLDNN"
)
{
...
@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type,
...
@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type,
}
}
OpKernelType
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
(),
OpKernelType
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
(),
StringToDataLayout
(
data_layout
),
StringToDataLayout
(
data_layout
),
StringToLibraryType
(
library_type
));
StringToLibraryType
(
library_type
)
,
customized_type_value
);
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
]
=
func
;
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
]
=
func
;
}
}
...
@@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
...
@@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using
KERNEL_TYPE
=
using
KERNEL_TYPE
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
KernelTypes
...
>>::
type
;
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
KernelTypes
...
>>::
type
;
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
const
{
using
T
=
typename
KERNEL_TYPE
::
ELEMENT_TYPE
;
using
T
=
typename
KERNEL_TYPE
::
ELEMENT_TYPE
;
RegisterKernelClass
<
PlaceType
,
T
>
(
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
[](
const
framework
::
ExecutionContext
&
ctx
)
{
op_type
,
library_type
,
customized_type_value
,
[](
const
framework
::
ExecutionContext
&
ctx
)
{
KERNEL_TYPE
().
Compute
(
ctx
);
KERNEL_TYPE
().
Compute
(
ctx
);
});
});
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelTypes
...
>>::
value
;
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelTypes
...
>>::
value
;
OpKernelRegistrarFunctor
<
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelTypes
...
>
OpKernelRegistrarFunctor
<
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelTypes
...
>
func
;
func
;
func
(
op_type
,
library_type
);
func
(
op_type
,
library_type
,
customized_type_value
);
}
}
};
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
KernelType
>
template
<
typename
PlaceType
,
size_t
I
,
typename
...
KernelType
>
struct
OpKernelRegistrarFunctor
<
PlaceType
,
true
,
I
,
KernelType
...
>
{
struct
OpKernelRegistrarFunctor
<
PlaceType
,
true
,
I
,
KernelType
...
>
{
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{}
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
const
{}
};
};
// User can register many kernel in one place. The data type could be
// User can register many kernel in one place. The data type could be
...
@@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
...
@@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
template
<
typename
PlaceType
,
typename
...
KernelType
>
template
<
typename
PlaceType
,
typename
...
KernelType
>
class
OpKernelRegistrar
:
public
Registrar
{
class
OpKernelRegistrar
:
public
Registrar
{
public:
public:
explicit
OpKernelRegistrar
(
const
char
*
op_type
,
const
char
*
library_type
)
{
explicit
OpKernelRegistrar
(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
{
OpKernelRegistrarFunctor
<
PlaceType
,
false
,
0
,
KernelType
...
>
func
;
OpKernelRegistrarFunctor
<
PlaceType
,
false
,
0
,
KernelType
...
>
func
;
func
(
op_type
,
library_type
);
func
(
op_type
,
library_type
,
customized_type_value
);
}
}
};
};
...
@@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx;
...
@@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx;
template
<
typename
PlaceType
,
typename
...
DataTypeAndKernelType
>
template
<
typename
PlaceType
,
typename
...
DataTypeAndKernelType
>
class
OpKernelRegistrarEx
:
public
Registrar
{
class
OpKernelRegistrarEx
:
public
Registrar
{
public:
public:
explicit
OpKernelRegistrarEx
(
const
char
*
op_type
,
const
char
*
library_type
)
{
explicit
OpKernelRegistrarEx
(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
{
OpKernelRegistrarFunctorEx
<
PlaceType
,
false
,
0
,
DataTypeAndKernelType
...
>
OpKernelRegistrarFunctorEx
<
PlaceType
,
false
,
0
,
DataTypeAndKernelType
...
>
func
;
func
;
func
(
op_type
,
library_type
);
func
(
op_type
,
library_type
,
customized_type_value
);
}
}
};
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
DataTypeAndKernelType
>
template
<
typename
PlaceType
,
size_t
I
,
typename
...
DataTypeAndKernelType
>
struct
OpKernelRegistrarFunctorEx
<
PlaceType
,
true
,
I
,
struct
OpKernelRegistrarFunctorEx
<
PlaceType
,
true
,
I
,
DataTypeAndKernelType
...
>
{
DataTypeAndKernelType
...
>
{
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{}
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
const
{}
};
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
DataTypeAndKernelType
>
template
<
typename
PlaceType
,
size_t
I
,
typename
...
DataTypeAndKernelType
>
...
@@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
...
@@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
typename
std
::
tuple_element
<
I
,
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
DataTypeAndKernelType
...
>>::
type
;
std
::
tuple
<
DataTypeAndKernelType
...
>>::
type
;
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
,
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
Functor
());
int
customized_type_value
)
const
{
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
customized_type_value
,
Functor
());
constexpr
auto
size
=
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
DataTypeAndKernelType
...
>>::
value
;
std
::
tuple_size
<
std
::
tuple
<
DataTypeAndKernelType
...
>>::
value
;
OpKernelRegistrarFunctorEx
<
PlaceType
,
I
+
2
>=
size
,
I
+
2
,
OpKernelRegistrarFunctorEx
<
PlaceType
,
I
+
2
>=
size
,
I
+
2
,
DataTypeAndKernelType
...
>
DataTypeAndKernelType
...
>
func
;
func
;
func
(
op_type
,
library_type
);
func
(
op_type
,
library_type
,
customized_type_value
);
}
}
};
};
// clang-format off
/**
/**
* check if MACRO is used in GLOBAL NAMESPACE.
* check if MACRO is used in GLOBAL NAMESPACE.
*/
*/
...
@@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
...
@@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
/**
/**
* Macro to register OperatorKernel.
* Macro to register OperatorKernel.
*/
*/
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type, \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
place_class, customized_name, \
__reg_op_kernel_##op_type##_##library_type##__, \
customized_type_value, ...) \
"REGISTER_OP_KERNEL must be called in global namespace"); \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
__reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
"REGISTER_OP_KERNEL must be called in " \
#library_type); \
"global namespace"); \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
static ::paddle::framework::OpKernelRegistrar<place_class, \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
__VA_ARGS__> \
return 0; \
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
#op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \
}
}
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( \
op_type, library_type, place_class, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
customized_name, \
__reg_op_kernel_##op_type##_##library_type##__, \
customized_type_value, \
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \
...) \
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
__reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
#library_type); \
"REGISTER_OP_KERNEL_EX must be called in " \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
"global namespace"); \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
static ::paddle::framework::OpKernelRegistrarEx<place_class, \
return 0; \
__VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
#op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \
}
}
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
REGISTER_OP_KERNEL_EX( \
__VA_ARGS__)
op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
REGISTER_OP_KERNEL_EX( \
op_type, CPU, ::paddle::platform::CPUPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
/**
/**
* Macro to mark what Operator and Kernel
* Macro to mark what Operator and Kernel
...
@@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
...
@@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
extern int TouchOpRegistrar_##op_type(); \
extern int TouchOpRegistrar_##op_type(); \
UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
LIBRARY_TYPE, \
__use_op_kernel_##op_type##_##LIBRARY_TYPE##__, \
customized_name) \
"USE_OP_DEVICE_KERNEL must be in global namespace"); \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
extern int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE(); \
__use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##__, \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_ = \
"USE_OP_DEVICE_KERNEL must be in global namespace"); \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE()
extern int \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##DEFAULT_TYPE##_ =
/* NOLINT */
\
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
// TODO(fengjiayi): The following macros
// TODO(fengjiayi): The following macros
// seems ugly, do we have better method?
// seems ugly, do we have better method?
...
@@ -280,6 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
...
@@ -280,6 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
#define USE_OP(op_type) \
#define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
USE_OP_KERNEL(op_type)
// clang-format off
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/operator_test.cc
浏览文件 @
73b4d1aa
...
@@ -50,6 +50,8 @@ class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -50,6 +50,8 @@ class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
AddInput
(
"input"
,
"input of test op"
);
AddInput
(
"input"
,
"input of test op"
);
AddOutput
(
"output"
,
"output of test op"
);
AddOutput
(
"output"
,
"output of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
);
AddAttr
<
int
>
(
"kernel_sub_type"
,
"kernels with different implementations."
)
.
SetDefault
(
0
);
AddComment
(
"This is test op"
);
AddComment
(
"This is test op"
);
}
}
};
};
...
@@ -95,6 +97,8 @@ TEST(OperatorBase, all) {
...
@@ -95,6 +97,8 @@ TEST(OperatorBase, all) {
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
static
int
special_type_value
=
1
;
class
OpKernelTestProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
OpKernelTestProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
{
void
Make
()
{
...
@@ -103,11 +107,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -103,11 +107,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
SetDefault
(
1.0
)
.
GreaterThan
(
0.0
);
.
GreaterThan
(
0.0
);
AddAttr
<
int
>
(
"kernel_sub_type"
,
"kernels with different implementations."
)
.
SetDefault
(
0
);
AddComment
(
"This is test op"
);
AddComment
(
"This is test op"
);
}
}
};
};
static
int
cpu_kernel_run_num
=
0
;
static
int
cpu_kernel_run_num
=
0
;
static
int
cpu_kernel2_run_num
=
0
;
class
OpWithKernelTest
:
public
OperatorWithKernel
{
class
OpWithKernelTest
:
public
OperatorWithKernel
{
public:
public:
...
@@ -117,7 +124,10 @@ class OpWithKernelTest : public OperatorWithKernel {
...
@@ -117,7 +124,10 @@ class OpWithKernelTest : public OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
OpKernelType
GetExpectedKernelType
(
OpKernelType
GetExpectedKernelType
(
const
ExecutionContext
&
ctx
)
const
override
{
const
ExecutionContext
&
ctx
)
const
override
{
return
OpKernelType
(
proto
::
VarType
::
FP32
,
ctx
.
GetPlace
());
int
sub_type
=
ctx
.
Attr
<
int
>
(
"kernel_sub_type"
);
return
OpKernelType
(
proto
::
VarType
::
FP32
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
framework
::
LibraryType
::
kPlain
,
sub_type
);
}
}
};
};
...
@@ -132,6 +142,17 @@ class CPUKernelTest : public OpKernel<float> {
...
@@ -132,6 +142,17 @@ class CPUKernelTest : public OpKernel<float> {
}
}
};
};
template
<
typename
T1
,
typename
T2
>
class
CPUKernel2Test
:
public
OpKernel
<
float
>
{
public:
void
Compute
(
const
ExecutionContext
&
ctx
)
const
{
std
::
cout
<<
ctx
.
op
().
DebugString
()
<<
std
::
endl
;
cpu_kernel2_run_num
++
;
ASSERT_EQ
(
ctx
.
op
().
Input
(
"x"
),
"IN1"
);
ASSERT_EQ
(
ctx
.
op
().
Output
(
"y"
),
"OUT1"
);
}
};
class
OpKernelTestMultiInputsProtoAndCheckerMaker
class
OpKernelTestMultiInputsProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
:
public
OpProtoAndCheckerMaker
{
public:
public:
...
@@ -142,6 +163,8 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
...
@@ -142,6 +163,8 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
SetDefault
(
1.0
)
.
GreaterThan
(
0.0
);
.
GreaterThan
(
0.0
);
AddAttr
<
int
>
(
"kernel_sub_type"
,
"kernels with different implementations."
)
.
SetDefault
(
0
);
AddComment
(
"This is test op"
);
AddComment
(
"This is test op"
);
}
}
};
};
...
@@ -189,9 +212,15 @@ class CPUKernalMultiInputsTest : public OpKernel<float> {
...
@@ -189,9 +212,15 @@ class CPUKernalMultiInputsTest : public OpKernel<float> {
REGISTER_OP_WITHOUT_GRADIENT
(
REGISTER_OP_WITHOUT_GRADIENT
(
op_with_kernel
,
paddle
::
framework
::
OpWithKernelTest
,
op_with_kernel
,
paddle
::
framework
::
OpWithKernelTest
,
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
paddle
::
framework
::
CPUKernelTest
<
float
,
float
>
);
paddle
::
framework
::
CPUKernelTest
<
float
,
float
>
);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
op_with_kernel
,
CPU
,
paddle
::
platform
::
CPUPlace
,
MY_SPECIAL_NAME
,
paddle
::
framework
::
special_type_value
,
paddle
::
framework
::
CPUKernel2Test
<
float
,
float
>
);
// test with single input
// test with single input
TEST
(
OpKernel
,
all
)
{
TEST
(
OpKernel
,
all
)
{
paddle
::
framework
::
InitDevices
(
true
);
paddle
::
framework
::
InitDevices
(
true
);
...
@@ -211,7 +240,19 @@ TEST(OpKernel, all) {
...
@@ -211,7 +240,19 @@ TEST(OpKernel, all) {
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
0
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
0
);
op
->
Run
(
scope
,
cpu_place
);
op
->
Run
(
scope
,
cpu_place
);
// kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called.
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel2_run_num
,
0
);
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"kernel_sub_type"
);
attr
->
set_type
(
paddle
::
framework
::
proto
::
AttrType
::
INT
);
attr
->
set_i
(
1
);
auto
op2
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
op2
->
Run
(
scope
,
cpu_place
);
// kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called.
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel2_run_num
,
1
);
}
}
REGISTER_OP_WITHOUT_GRADIENT
(
REGISTER_OP_WITHOUT_GRADIENT
(
...
...
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
73b4d1aa
...
@@ -491,8 +491,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -491,8 +491,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_KERNEL
(
conv2d
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
conv2d
,
MKLDNN
,
ops
::
ConvMKLDNNOpKernel
<
float
>
);
::
paddle
::
platform
::
CPUPlace
,
FP32
,
ops
::
kConvMKLDNNFP32
,
REGISTER_OP_KERNEL
(
conv2d_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
ops
::
ConvMKLDNNOpKernel
<
float
>
);
ops
::
ConvMKLDNNGradOpKernel
<
float
>
);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
conv2d_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
FP32
,
ops
::
kConvMKLDNNFP32
,
ops
::
ConvMKLDNNGradOpKernel
<
float
>
);
paddle/fluid/operators/conv_op.cc
浏览文件 @
73b4d1aa
...
@@ -74,6 +74,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -74,6 +74,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
ConvOp
::
GetExpectedKernelType
(
framework
::
OpKernelType
ConvOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
int
customized_type_value
=
framework
::
OpKernelType
::
kDefaultCustomizedTypeValue
;
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
...
@@ -89,6 +91,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
...
@@ -89,6 +91,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library
=
framework
::
LibraryType
::
kMKLDNN
;
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
customized_type_value
=
kConvMKLDNNFP32
;
}
}
#endif
#endif
...
@@ -105,7 +108,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
...
@@ -105,7 +108,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
}
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library
);
library
,
customized_type_value
);
}
}
void
Conv2DOpMaker
::
Make
()
{
void
Conv2DOpMaker
::
Make
()
{
...
@@ -342,6 +345,8 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -342,6 +345,8 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
ConvOpGrad
::
GetExpectedKernelType
(
framework
::
OpKernelType
ConvOpGrad
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
int
customized_type_value
=
framework
::
OpKernelType
::
kDefaultCustomizedTypeValue
;
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
...
@@ -357,12 +362,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
...
@@ -357,12 +362,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
customized_type_value
=
kConvMKLDNNFP32
;
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
GetPlace
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
GetPlace
(),
layout_
,
library_
);
layout_
,
library_
,
customized_type_value
);
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/conv_op.h
浏览文件 @
73b4d1aa
...
@@ -27,6 +27,8 @@ namespace paddle {
...
@@ -27,6 +27,8 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
constexpr
int
kConvMKLDNNFP32
=
1
;
constexpr
int
kConvMKLDNNINT8
=
2
;
// Base convolution operator definations for other conv
// Base convolution operator definations for other conv
// like operators to reuse the implementation.
// like operators to reuse the implementation.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录