Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
41c28d54
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
41c28d54
编写于
12月 05, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
allow customize kernel selection
test=develop
上级
0e3048db
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
232 addition
and
85 deletion
+232
-85
cmake/operators.cmake
cmake/operators.cmake
+2
-0
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+3
-2
paddle/fluid/framework/op_kernel_type.cc
paddle/fluid/framework/op_kernel_type.cc
+54
-0
paddle/fluid/framework/op_kernel_type.h
paddle/fluid/framework/op_kernel_type.h
+30
-29
paddle/fluid/framework/op_registry.h
paddle/fluid/framework/op_registry.h
+85
-45
paddle/fluid/framework/operator_test.cc
paddle/fluid/framework/operator_test.cc
+43
-3
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+9
-5
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+4
-1
paddle/fluid/operators/conv_op.h
paddle/fluid/operators/conv_op.h
+2
-0
未找到文件。
cmake/operators.cmake
浏览文件 @
41c28d54
...
...
@@ -166,6 +166,8 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator
if
(
${
MKLDNN_FILE
}
STREQUAL
"activation_mkldnn_op"
)
file
(
APPEND
${
pybind_file
}
"USE_OP_DEVICE_KERNEL(relu, MKLDNN);
\n
"
)
elseif
(
${
MKLDNN_FILE
}
STREQUAL
"conv_mkldnn_op"
)
file
(
APPEND
${
pybind_file
}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
\n
"
)
else
()
file
(
APPEND
${
pybind_file
}
"USE_OP_DEVICE_KERNEL(
${
TARGET
}
, MKLDNN);
\n
"
)
endif
()
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
41c28d54
...
...
@@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library
(
shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context
)
cc_library
(
transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context
)
cc_library
(
op_kernel_type SRCS op_kernel_type.cc DEPS device_context place
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache
)
shape_inference data_transform lod_tensor profiler transfer_scope_cache
op_kernel_type
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry device_context
)
...
...
@@ -191,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library
(
selected_rows SRCS selected_rows.cc DEPS tensor
)
cc_test
(
selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows
)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto
)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto
op_kernel_type
)
cc_test
(
cow_ptr_tests SRCS details/cow_ptr_test.cc
)
cc_test
(
tuple_test SRCS tuple_test.cc
)
...
...
paddle/fluid/framework/op_kernel_type.cc
0 → 100644
浏览文件 @
41c28d54
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_kernel_type.h"
namespace
paddle
{
namespace
framework
{
size_t
OpKernelType
::
Hash
::
operator
()(
const
OpKernelType
&
key
)
const
{
int
cur_loc
=
0
;
int
place
=
key
.
place_
.
which
();
cur_loc
+=
OpKernelType
::
kPlaceBits
;
int
data_type
=
static_cast
<
int
>
(
key
.
data_type_
)
<<
cur_loc
;
cur_loc
+=
OpKernelType
::
kPrimaryDTypeBits
;
int
data_layout
=
static_cast
<
int
>
(
key
.
data_layout_
)
<<
cur_loc
;
cur_loc
+=
OpKernelType
::
kLayoutBits
;
int
library_type
=
static_cast
<
int
>
(
key
.
library_type_
)
<<
cur_loc
;
cur_loc
+=
OpKernelType
::
kLibBits
;
int
customized_value
=
key
.
customized_type_value_
;
PADDLE_ENFORCE
(
customized_value
<
(
1
<<
OpKernelType
::
kCustomizeBits
));
customized_value
=
customized_value
<<
cur_loc
;
cur_loc
+=
OpKernelType
::
kCustomizeBits
;
PADDLE_ENFORCE
(
cur_loc
<
64
);
std
::
hash
<
int
>
hasher
;
return
hasher
(
place
+
data_type
+
data_layout
+
library_type
+
customized_value
);
}
bool
OpKernelType
::
operator
==
(
const
OpKernelType
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
data_type_
==
o
.
data_type_
&&
data_layout_
==
o
.
data_layout_
&&
library_type_
==
o
.
library_type_
&&
customized_type_value_
==
o
.
customized_type_value_
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/op_kernel_type.h
浏览文件 @
41c28d54
...
...
@@ -24,54 +24,55 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
struct
OpKernelType
{
struct
Hash
{
size_t
operator
()(
const
OpKernelType
&
key
)
const
{
int
place
=
key
.
place_
.
which
();
int
data_type
=
static_cast
<
int
>
(
key
.
data_type_
)
<<
LEFT_SHIFT
;
int
data_layout
=
static_cast
<
int
>
(
key
.
data_layout_
)
<<
(
LEFT_SHIFT
*
2
);
int
library_type
=
static_cast
<
int
>
(
key
.
library_type_
)
<<
(
LEFT_SHIFT
*
3
);
std
::
hash
<
int
>
hasher
;
return
hasher
(
place
+
data_type
+
data_layout
+
library_type
);
}
};
class
OpKernelType
{
public:
constexpr
static
int
kDefaultCustomizedTypeValue
=
0
;
// place, data_type, library_type kinds less than 2^8
constexpr
static
int
LEFT_SHIFT
=
8
;
proto
::
VarType
::
Type
data_type_
;
DataLayout
data_layout_
;
platform
::
Place
place_
;
LibraryType
library_type_
;
// In total should be smaller than 64.
constexpr
static
int
kPlaceBits
=
4
;
constexpr
static
int
kPrimaryDTypeBits
=
8
;
constexpr
static
int
kLayoutBits
=
4
;
constexpr
static
int
kLibBits
=
4
;
constexpr
static
int
kCustomizeBits
=
4
;
OpKernelType
(
proto
::
VarType
::
Type
data_type
,
platform
::
Place
place
,
DataLayout
data_layout
=
DataLayout
::
kAnyLayout
,
LibraryType
library_type
=
LibraryType
::
kPlain
)
LibraryType
library_type
=
LibraryType
::
kPlain
,
int
customized_type_value
=
kDefaultCustomizedTypeValue
)
:
data_type_
(
data_type
),
data_layout_
(
data_layout
),
place_
(
place
),
library_type_
(
library_type
)
{}
library_type_
(
library_type
),
customized_type_value_
(
customized_type_value
)
{}
OpKernelType
(
proto
::
VarType
::
Type
data_type
,
const
platform
::
DeviceContext
&
dev_ctx
,
DataLayout
data_layout
=
DataLayout
::
kAnyLayout
,
LibraryType
library_type
=
LibraryType
::
kPlain
)
LibraryType
library_type
=
LibraryType
::
kPlain
,
int
customized_type_value
=
kDefaultCustomizedTypeValue
)
:
data_type_
(
data_type
),
data_layout_
(
data_layout
),
place_
(
dev_ctx
.
GetPlace
()),
library_type_
(
library_type
)
{}
library_type_
(
library_type
),
customized_type_value_
(
customized_type_value
)
{}
virtual
~
OpKernelType
()
{}
struct
Hash
{
size_t
operator
()(
const
OpKernelType
&
key
)
const
;
};
size_t
hash_key
()
const
{
return
Hash
()(
*
this
);
}
bool
operator
==
(
const
OpKernelType
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
data_type_
==
o
.
data_type_
&&
data_layout_
==
o
.
data_layout_
&&
library_type_
==
o
.
library_type_
;
}
bool
operator
==
(
const
OpKernelType
&
o
)
const
;
bool
operator
!=
(
const
OpKernelType
&
o
)
const
{
return
!
(
*
this
==
o
);
}
proto
::
VarType
::
Type
data_type_
;
DataLayout
data_layout_
;
platform
::
Place
place_
;
LibraryType
library_type_
;
int
customized_type_value_
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
...
...
paddle/fluid/framework/op_registry.h
浏览文件 @
41c28d54
...
...
@@ -35,6 +35,7 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
class
Registrar
{
public:
// In our design, various kinds of classes, e.g., operators and kernels,
...
...
@@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor;
template
<
typename
PlaceType
,
typename
T
,
typename
Func
>
inline
void
RegisterKernelClass
(
const
char
*
op_type
,
const
char
*
library_type
,
Func
func
)
{
int
customized_type_value
,
Func
func
)
{
std
::
string
library
(
library_type
);
std
::
string
data_layout
=
"ANYLAYOUT"
;
if
(
library
==
"MKLDNN"
)
{
...
...
@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type,
}
OpKernelType
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
(),
StringToDataLayout
(
data_layout
),
StringToLibraryType
(
library_type
));
StringToLibraryType
(
library_type
)
,
customized_type_value
);
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
]
=
func
;
}
...
...
@@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using
KERNEL_TYPE
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
KernelTypes
...
>>::
type
;
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
const
{
using
T
=
typename
KERNEL_TYPE
::
ELEMENT_TYPE
;
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
[](
const
framework
::
ExecutionContext
&
ctx
)
{
op_type
,
library_type
,
customized_type_value
,
[](
const
framework
::
ExecutionContext
&
ctx
)
{
KERNEL_TYPE
().
Compute
(
ctx
);
});
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelTypes
...
>>::
value
;
OpKernelRegistrarFunctor
<
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelTypes
...
>
func
;
func
(
op_type
,
library_type
);
func
(
op_type
,
library_type
,
customized_type_value
);
}
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
KernelType
>
struct
OpKernelRegistrarFunctor
<
PlaceType
,
true
,
I
,
KernelType
...
>
{
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{}
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
const
{}
};
// User can register many kernel in one place. The data type could be
...
...
@@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
template
<
typename
PlaceType
,
typename
...
KernelType
>
class
OpKernelRegistrar
:
public
Registrar
{
public:
explicit
OpKernelRegistrar
(
const
char
*
op_type
,
const
char
*
library_type
)
{
explicit
OpKernelRegistrar
(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
{
OpKernelRegistrarFunctor
<
PlaceType
,
false
,
0
,
KernelType
...
>
func
;
func
(
op_type
,
library_type
);
func
(
op_type
,
library_type
,
customized_type_value
);
}
};
...
...
@@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx;
template
<
typename
PlaceType
,
typename
...
DataTypeAndKernelType
>
class
OpKernelRegistrarEx
:
public
Registrar
{
public:
explicit
OpKernelRegistrarEx
(
const
char
*
op_type
,
const
char
*
library_type
)
{
explicit
OpKernelRegistrarEx
(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
{
OpKernelRegistrarFunctorEx
<
PlaceType
,
false
,
0
,
DataTypeAndKernelType
...
>
func
;
func
(
op_type
,
library_type
);
func
(
op_type
,
library_type
,
customized_type_value
);
}
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
DataTypeAndKernelType
>
struct
OpKernelRegistrarFunctorEx
<
PlaceType
,
true
,
I
,
DataTypeAndKernelType
...
>
{
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{}
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
const
{}
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
DataTypeAndKernelType
>
...
...
@@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
DataTypeAndKernelType
...
>>::
type
;
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
Functor
());
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
,
int
customized_type_value
)
const
{
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
customized_type_value
,
Functor
());
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
DataTypeAndKernelType
...
>>::
value
;
OpKernelRegistrarFunctorEx
<
PlaceType
,
I
+
2
>=
size
,
I
+
2
,
DataTypeAndKernelType
...
>
func
;
func
(
op_type
,
library_type
);
func
(
op_type
,
library_type
,
customized_type_value
);
}
};
// clang-format off
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
...
...
@@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
/**
* Macro to register OperatorKernel.
*/
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type, \
place_class, customized_name, \
customized_type_value, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
#library_type); \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
__reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
"REGISTER_OP_KERNEL must be called in " \
"global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, \
__VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
#op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \
}
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( \
op_type, library_type, place_class, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \
customized_name, \
customized_type_value, \
...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
#library_type); \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
__reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
"REGISTER_OP_KERNEL_EX must be called in " \
"global namespace"); \
static ::paddle::framework::OpKernelRegistrarEx<place_class, \
__VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
#op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \
}
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
REGISTER_OP_KERNEL_EX( \
op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
REGISTER_OP_KERNEL_EX( \
op_type, CPU, ::paddle::platform::CPUPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
/**
* Macro to mark what Operator and Kernel
...
...
@@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
extern int TouchOpRegistrar_##op_type(); \
UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
LIBRARY_TYPE, \
customized_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##LIBRARY_TYPE##_
_,
\
__use_op_kernel_##op_type##_##LIBRARY_TYPE##_
##customized_name##__,
\
"USE_OP_DEVICE_KERNEL must be in global namespace"); \
extern int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE(); \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_ = \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE()
extern int \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##DEFAULT_TYPE##_ =
/* NOLINT */
\
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
// TODO(fengjiayi): The following macros
// seems ugly, do we have better method?
...
...
@@ -280,6 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
#define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
// clang-format off
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/operator_test.cc
浏览文件 @
41c28d54
...
...
@@ -50,6 +50,8 @@ class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
AddInput
(
"input"
,
"input of test op"
);
AddOutput
(
"output"
,
"output of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
);
AddAttr
<
int
>
(
"kernel_sub_type"
,
"kernels with different implementations."
)
.
SetDefault
(
0
);
AddComment
(
"This is test op"
);
}
};
...
...
@@ -103,11 +105,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
GreaterThan
(
0.0
);
AddAttr
<
int
>
(
"kernel_sub_type"
,
"kernels with different implementations."
)
.
SetDefault
(
0
);
AddComment
(
"This is test op"
);
}
};
static
int
cpu_kernel_run_num
=
0
;
static
int
cpu_kernel2_run_num
=
0
;
class
OpWithKernelTest
:
public
OperatorWithKernel
{
public:
...
...
@@ -117,7 +122,10 @@ class OpWithKernelTest : public OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
OpKernelType
GetExpectedKernelType
(
const
ExecutionContext
&
ctx
)
const
override
{
return
OpKernelType
(
proto
::
VarType
::
FP32
,
ctx
.
GetPlace
());
int
sub_type
=
ctx
.
Attr
<
int
>
(
"kernel_sub_type"
);
return
OpKernelType
(
proto
::
VarType
::
FP32
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
framework
::
LibraryType
::
kPlain
,
sub_type
);
}
};
...
...
@@ -132,6 +140,17 @@ class CPUKernelTest : public OpKernel<float> {
}
};
template
<
typename
T1
,
typename
T2
>
class
CPUKernel2Test
:
public
OpKernel
<
float
>
{
public:
void
Compute
(
const
ExecutionContext
&
ctx
)
const
{
std
::
cout
<<
ctx
.
op
().
DebugString
()
<<
std
::
endl
;
cpu_kernel2_run_num
++
;
ASSERT_EQ
(
ctx
.
op
().
Input
(
"x"
),
"IN1"
);
ASSERT_EQ
(
ctx
.
op
().
Output
(
"y"
),
"OUT1"
);
}
};
class
OpKernelTestMultiInputsProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
...
...
@@ -142,6 +161,8 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
GreaterThan
(
0.0
);
AddAttr
<
int
>
(
"kernel_sub_type"
,
"kernels with different implementations."
)
.
SetDefault
(
0
);
AddComment
(
"This is test op"
);
}
};
...
...
@@ -189,9 +210,18 @@ class CPUKernalMultiInputsTest : public OpKernel<float> {
REGISTER_OP_WITHOUT_GRADIENT
(
op_with_kernel
,
paddle
::
framework
::
OpWithKernelTest
,
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
// REGISTER_OP_CPU_KERNEL(op_with_kernel,
// paddle::framework::CPUKernelTest<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
op_with_kernel
,
CPU
,
paddle
::
platform
::
CPUPlace
,
DEFAULT_TYPE
,
0
,
paddle
::
framework
::
CPUKernelTest
<
float
,
float
>
);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
op_with_kernel
,
CPU
,
paddle
::
platform
::
CPUPlace
,
SPECIAL
,
1
,
paddle
::
framework
::
CPUKernel2Test
<
float
,
float
>
);
// test with single input
TEST
(
OpKernel
,
all
)
{
paddle
::
framework
::
InitDevices
(
true
);
...
...
@@ -212,6 +242,16 @@ TEST(OpKernel, all) {
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
0
);
op
->
Run
(
scope
,
cpu_place
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel2_run_num
,
0
);
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"kernel_sub_type"
);
attr
->
set_type
(
paddle
::
framework
::
proto
::
AttrType
::
INT
);
attr
->
set_i
(
1
);
auto
op2
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
op2
->
Run
(
scope
,
cpu_place
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel2_run_num
,
1
);
}
REGISTER_OP_WITHOUT_GRADIENT
(
...
...
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
41c28d54
...
...
@@ -491,8 +491,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_KERNEL
(
conv2d
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
conv2d
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
FP32
,
ops
::
kConvMKLDNNFP32
,
ops
::
ConvMKLDNNOpKernel
<
float
>
);
REGISTER_OP_KERNEL
(
conv2d_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
conv2d_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
FP32
,
ops
::
kConvMKLDNNFP32
,
ops
::
ConvMKLDNNGradOpKernel
<
float
>
);
paddle/fluid/operators/conv_op.cc
浏览文件 @
41c28d54
...
...
@@ -74,6 +74,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
ConvOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
int
customized_type_value
=
framework
::
OpKernelType
::
kDefaultCustomizedTypeValue
;
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
...
...
@@ -89,6 +91,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
customized_type_value
=
kConvMKLDNNFP32
;
}
#endif
...
...
@@ -105,7 +108,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library
);
library
,
customized_type_value
);
}
void
Conv2DOpMaker
::
Make
()
{
...
...
paddle/fluid/operators/conv_op.h
浏览文件 @
41c28d54
...
...
@@ -27,6 +27,8 @@ namespace paddle {
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
constexpr
int
kConvMKLDNNFP32
=
1
;
constexpr
int
kConvMKLDNNINT8
=
2
;
// Base convolution operator definations for other conv
// like operators to reuse the implementation.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录