提交 41c28d54 编写于 作者: X Xin Pan

allow customize kernel selection

test=develop
上级 0e3048db
...@@ -166,6 +166,8 @@ function(op_library TARGET) ...@@ -166,6 +166,8 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator # Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif() endif()
......
...@@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) ...@@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache) shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
...@@ -191,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry ...@@ -191,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto op_kernel_type)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(tuple_test SRCS tuple_test.cc ) cc_test(tuple_test SRCS tuple_test.cc )
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_kernel_type.h"
namespace paddle {
namespace framework {
// Hashes an OpKernelType by packing its fields into disjoint bit ranges of a
// single int (place in the lowest bits, followed by data type, data layout,
// library type and the customized type value) and hashing that integer.
//
// key: the kernel type to hash.
// Returns std::hash<int> applied to the packed integer.
// PADDLE_ENFORCE fails when key.customized_type_value_ does not fit in
// kCustomizeBits bits.
size_t OpKernelType::Hash::operator()(const OpKernelType& key) const {
  // The key is assembled in an `int`, so every field must fit inside its
  // width: shifting by the width of the type (or more) is undefined.
  static_assert(OpKernelType::kPlaceBits + OpKernelType::kPrimaryDTypeBits +
                        OpKernelType::kLayoutBits + OpKernelType::kLibBits +
                        OpKernelType::kCustomizeBits <
                    static_cast<int>(sizeof(int) * 8),
                "OpKernelType hash fields must fit in an int");

  int cur_loc = 0;

  int place = key.place_.which();
  cur_loc += OpKernelType::kPlaceBits;

  int data_type = static_cast<int>(key.data_type_) << cur_loc;
  cur_loc += OpKernelType::kPrimaryDTypeBits;

  int data_layout = static_cast<int>(key.data_layout_) << cur_loc;
  cur_loc += OpKernelType::kLayoutBits;

  int library_type = static_cast<int>(key.library_type_) << cur_loc;
  cur_loc += OpKernelType::kLibBits;

  int customized_value = key.customized_type_value_;
  // A negative value would pass the upper-bound check alone and then be
  // left-shifted, so the lower bound is checked explicitly.
  PADDLE_ENFORCE(customized_value >= 0);
  PADDLE_ENFORCE(customized_value < (1 << OpKernelType::kCustomizeBits));
  customized_value = customized_value << cur_loc;

  std::hash<int> hasher;
  return hasher(place + data_type + data_layout + library_type +
                customized_value);
}
// Equality for kernel types: true only when
// platform::places_are_same_class accepts the two places and the data type,
// data layout, library type and customized type value are all equal.
bool OpKernelType::operator==(const OpKernelType& o) const {
  if (!platform::places_are_same_class(place_, o.place_)) {
    return false;
  }
  if (!(data_type_ == o.data_type_)) {
    return false;
  }
  if (!(data_layout_ == o.data_layout_)) {
    return false;
  }
  if (!(library_type_ == o.library_type_)) {
    return false;
  }
  return customized_type_value_ == o.customized_type_value_;
}
} // namespace framework
} // namespace paddle
...@@ -24,54 +24,55 @@ limitations under the License. */ ...@@ -24,54 +24,55 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct OpKernelType { class OpKernelType {
struct Hash { public:
size_t operator()(const OpKernelType& key) const { constexpr static int kDefaultCustomizedTypeValue = 0;
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);
int library_type = static_cast<int>(key.library_type_)
<< (LEFT_SHIFT * 3);
std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type);
}
};
// place, data_type, library_type kinds less than 2^8 // In total should be smaller than 64.
constexpr static int LEFT_SHIFT = 8; constexpr static int kPlaceBits = 4;
constexpr static int kPrimaryDTypeBits = 8;
proto::VarType::Type data_type_; constexpr static int kLayoutBits = 4;
DataLayout data_layout_; constexpr static int kLibBits = 4;
platform::Place place_; constexpr static int kCustomizeBits = 4;
LibraryType library_type_;
OpKernelType(proto::VarType::Type data_type, platform::Place place, OpKernelType(proto::VarType::Type data_type, platform::Place place,
DataLayout data_layout = DataLayout::kAnyLayout, DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain) LibraryType library_type = LibraryType::kPlain,
int customized_type_value = kDefaultCustomizedTypeValue)
: data_type_(data_type), : data_type_(data_type),
data_layout_(data_layout), data_layout_(data_layout),
place_(place), place_(place),
library_type_(library_type) {} library_type_(library_type),
customized_type_value_(customized_type_value) {}
OpKernelType(proto::VarType::Type data_type, OpKernelType(proto::VarType::Type data_type,
const platform::DeviceContext& dev_ctx, const platform::DeviceContext& dev_ctx,
DataLayout data_layout = DataLayout::kAnyLayout, DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain) LibraryType library_type = LibraryType::kPlain,
int customized_type_value = kDefaultCustomizedTypeValue)
: data_type_(data_type), : data_type_(data_type),
data_layout_(data_layout), data_layout_(data_layout),
place_(dev_ctx.GetPlace()), place_(dev_ctx.GetPlace()),
library_type_(library_type) {} library_type_(library_type),
customized_type_value_(customized_type_value) {}
virtual ~OpKernelType() {}
struct Hash {
size_t operator()(const OpKernelType& key) const;
};
size_t hash_key() const { return Hash()(*this); } size_t hash_key() const { return Hash()(*this); }
bool operator==(const OpKernelType& o) const { bool operator==(const OpKernelType& o) const;
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_;
}
bool operator!=(const OpKernelType& o) const { return !(*this == o); } bool operator!=(const OpKernelType& o) const { return !(*this == o); }
proto::VarType::Type data_type_;
DataLayout data_layout_;
platform::Place place_;
LibraryType library_type_;
int customized_type_value_;
}; };
inline std::ostream& operator<<(std::ostream& os, inline std::ostream& operator<<(std::ostream& os,
......
...@@ -35,6 +35,7 @@ limitations under the License. */ ...@@ -35,6 +35,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Registrar { class Registrar {
public: public:
// In our design, various kinds of classes, e.g., operators and kernels, // In our design, various kinds of classes, e.g., operators and kernels,
...@@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor; ...@@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor;
template <typename PlaceType, typename T, typename Func> template <typename PlaceType, typename T, typename Func>
inline void RegisterKernelClass(const char* op_type, const char* library_type, inline void RegisterKernelClass(const char* op_type, const char* library_type,
Func func) { int customized_type_value, Func func) {
std::string library(library_type); std::string library(library_type);
std::string data_layout = "ANYLAYOUT"; std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") { if (library == "MKLDNN") {
...@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type, ...@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type,
} }
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
StringToDataLayout(data_layout), StringToDataLayout(data_layout),
StringToLibraryType(library_type)); StringToLibraryType(library_type), customized_type_value);
OperatorWithKernel::AllOpKernels()[op_type][key] = func; OperatorWithKernel::AllOpKernels()[op_type][key] = func;
} }
...@@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> { ...@@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE = using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type; typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE; using T = typename KERNEL_TYPE::ELEMENT_TYPE;
RegisterKernelClass<PlaceType, T>( RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) { op_type, library_type, customized_type_value,
[](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx); KERNEL_TYPE().Compute(ctx);
}); });
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value; constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...> OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
template <typename PlaceType, size_t I, typename... KernelType> template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> { struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type, const char* library_type) const {} void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {}
}; };
// User can register many kernel in one place. The data type could be // User can register many kernel in one place. The data type could be
...@@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> { ...@@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
template <typename PlaceType, typename... KernelType> template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar { class OpKernelRegistrar : public Registrar {
public: public:
explicit OpKernelRegistrar(const char* op_type, const char* library_type) { explicit OpKernelRegistrar(const char* op_type, const char* library_type,
int customized_type_value) {
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func; OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
...@@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx; ...@@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx;
template <typename PlaceType, typename... DataTypeAndKernelType> template <typename PlaceType, typename... DataTypeAndKernelType>
class OpKernelRegistrarEx : public Registrar { class OpKernelRegistrarEx : public Registrar {
public: public:
explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) { explicit OpKernelRegistrarEx(const char* op_type, const char* library_type,
int customized_type_value) {
OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...> OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType> template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, true, I, struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
DataTypeAndKernelType...> { DataTypeAndKernelType...> {
void operator()(const char* op_type, const char* library_type) const {} void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {}
}; };
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType> template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
...@@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
typename std::tuple_element<I, typename std::tuple_element<I,
std::tuple<DataTypeAndKernelType...>>::type; std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type,
RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor()); int customized_type_value) const {
RegisterKernelClass<PlaceType, T>(op_type, library_type,
customized_type_value, Functor());
constexpr auto size = constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value; std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2, OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
DataTypeAndKernelType...> DataTypeAndKernelType...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
// clang-format off
/** /**
* check if MACRO is used in GLOBAL NAMESPACE. * check if MACRO is used in GLOBAL NAMESPACE.
*/ */
...@@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
*/ */
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \ #define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type, \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ place_class, customized_name, \
__reg_op_kernel_##op_type##_##library_type##__, \ customized_type_value, ...) \
"REGISTER_OP_KERNEL must be called in global namespace"); \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \ __reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ "REGISTER_OP_KERNEL must be called in " \
#library_type); \ "global namespace"); \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \ static ::paddle::framework::OpKernelRegistrar<place_class, \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ __VA_ARGS__> \
return 0; \ __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
#op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \
} }
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( \
op_type, library_type, place_class, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \ #define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \ #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \ #define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ customized_name, \
__reg_op_kernel_##op_type##_##library_type##__, \ customized_type_value, \
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \ ...) \
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ __reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
#library_type); \ "REGISTER_OP_KERNEL_EX must be called in " \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \ "global namespace"); \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ static ::paddle::framework::OpKernelRegistrarEx<place_class, \
return 0; \ __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
#op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \
} }
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \ #define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \ REGISTER_OP_KERNEL_EX( \
__VA_ARGS__) op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \ #define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL_EX( \
op_type, CPU, ::paddle::platform::CPUPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
/** /**
* Macro to mark what Operator and Kernel * Macro to mark what Operator and Kernel
...@@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
extern int TouchOpRegistrar_##op_type(); \ extern int TouchOpRegistrar_##op_type(); \
UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type() UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \ #define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ LIBRARY_TYPE, \
__use_op_kernel_##op_type##_##LIBRARY_TYPE##__, \ customized_name) \
"USE_OP_DEVICE_KERNEL must be in global namespace"); \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
extern int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE(); \ __use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##__, \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_ = \ "USE_OP_DEVICE_KERNEL must be in global namespace"); \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE() extern int \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##DEFAULT_TYPE##_ = /* NOLINT */ \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
// TODO(fengjiayi): The following macros // TODO(fengjiayi): The following macros
// seems ugly, do we have better method? // seems ugly, do we have better method?
...@@ -280,6 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -280,6 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
#define USE_OP(op_type) \ #define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type) USE_OP_KERNEL(op_type)
// clang-format on
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -50,6 +50,8 @@ class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -50,6 +50,8 @@ class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
AddInput("input", "input of test op"); AddInput("input", "input of test op");
AddOutput("output", "output of test op"); AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op"); AddAttr<float>("scale", "scale of cosine op");
AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
.SetDefault(0);
AddComment("This is test op"); AddComment("This is test op");
} }
}; };
...@@ -103,11 +105,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -103,11 +105,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.GreaterThan(0.0); .GreaterThan(0.0);
AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
.SetDefault(0);
AddComment("This is test op"); AddComment("This is test op");
} }
}; };
static int cpu_kernel_run_num = 0; static int cpu_kernel_run_num = 0;
static int cpu_kernel2_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
public: public:
...@@ -117,7 +122,10 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -117,7 +122,10 @@ class OpWithKernelTest : public OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
return OpKernelType(proto::VarType::FP32, ctx.GetPlace()); int sub_type = ctx.Attr<int>("kernel_sub_type");
return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
framework::DataLayout::kAnyLayout,
framework::LibraryType::kPlain, sub_type);
} }
}; };
...@@ -132,6 +140,17 @@ class CPUKernelTest : public OpKernel<float> { ...@@ -132,6 +140,17 @@ class CPUKernelTest : public OpKernel<float> {
} }
}; };
// Test kernel registered under a non-default customized type value.
// Each Compute call prints the op's debug string, bumps
// cpu_kernel2_run_num, and checks the op's "x" input and "y" output names.
template <typename T1, typename T2>
class CPUKernel2Test : public OpKernel<float> {
 public:
  void Compute(const ExecutionContext& ctx) const {
    const auto& current_op = ctx.op();
    std::cout << current_op.DebugString() << std::endl;
    // Counted separately from cpu_kernel_run_num so the test can tell which
    // kernel was selected.
    ++cpu_kernel2_run_num;
    ASSERT_EQ(current_op.Input("x"), "IN1");
    ASSERT_EQ(current_op.Output("y"), "OUT1");
  }
};
class OpKernelTestMultiInputsProtoAndCheckerMaker class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker { : public OpProtoAndCheckerMaker {
public: public:
...@@ -142,6 +161,8 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker ...@@ -142,6 +161,8 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.GreaterThan(0.0); .GreaterThan(0.0);
AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
.SetDefault(0);
AddComment("This is test op"); AddComment("This is test op");
} }
}; };
...@@ -189,8 +210,17 @@ class CPUKernalMultiInputsTest : public OpKernel<float> { ...@@ -189,8 +210,17 @@ class CPUKernalMultiInputsTest : public OpKernel<float> {
REGISTER_OP_WITHOUT_GRADIENT( REGISTER_OP_WITHOUT_GRADIENT(
op_with_kernel, paddle::framework::OpWithKernelTest, op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker); paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel,
paddle::framework::CPUKernelTest<float, float>); // REGISTER_OP_CPU_KERNEL(op_with_kernel,
// paddle::framework::CPUKernelTest<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
op_with_kernel, CPU, paddle::platform::CPUPlace, DEFAULT_TYPE, 0,
paddle::framework::CPUKernelTest<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
op_with_kernel, CPU, paddle::platform::CPUPlace, SPECIAL, 1,
paddle::framework::CPUKernel2Test<float, float>);
// test with single input // test with single input
TEST(OpKernel, all) { TEST(OpKernel, all) {
...@@ -212,6 +242,16 @@ TEST(OpKernel, all) { ...@@ -212,6 +242,16 @@ TEST(OpKernel, all) {
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_place); op->Run(scope, cpu_place);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);
attr = op_desc.mutable_attrs()->Add();
attr->set_name("kernel_sub_type");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(1);
auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
op2->Run(scope, cpu_place);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
} }
REGISTER_OP_WITHOUT_GRADIENT( REGISTER_OP_WITHOUT_GRADIENT(
......
...@@ -491,8 +491,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -491,8 +491,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
ops::ConvMKLDNNOpKernel<float>); ::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::ConvMKLDNNOpKernel<float>);
ops::ConvMKLDNNGradOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float>);
...@@ -74,6 +74,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -74,6 +74,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
...@@ -89,6 +91,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -89,6 +91,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32;
} }
#endif #endif
...@@ -105,7 +108,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -105,7 +108,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
} }
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library); library, customized_type_value);
} }
void Conv2DOpMaker::Make() { void Conv2DOpMaker::Make() {
......
...@@ -27,6 +27,8 @@ namespace paddle { ...@@ -27,6 +27,8 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
constexpr int kConvMKLDNNFP32 = 1;
constexpr int kConvMKLDNNINT8 = 2;
// Base convolution operator definitions for other conv // Base convolution operator definitions for other conv
// like operators to reuse the implementation. // like operators to reuse the implementation.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册