未验证 提交 75ae426a 编写于 作者: Y yuyang18

Merge branch 'feature/change_op_kernel_to_func' into feature/fix_reshape_op_size

...@@ -76,8 +76,9 @@ class OpRegistry { ...@@ -76,8 +76,9 @@ class OpRegistry {
template <typename PlaceType, bool at_end, size_t I, typename... KernelType> template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor; struct OpKernelRegistrarFunctor;
template <typename PlaceType, typename T, typename KernelType> template <typename PlaceType, typename T, typename Func>
inline void RegisterKernelClass(const char* op_type, const char* library_type) { inline void RegisterKernelClass(const char* op_type, const char* library_type,
Func func) {
std::string library(library_type); std::string library(library_type);
std::string data_layout = "ANYLAYOUT"; std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") { if (library == "MKLDNN") {
...@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type) { ...@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type) {
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
StringToDataLayout(data_layout), StringToDataLayout(data_layout),
StringToLibraryType(library_type)); StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType()); OperatorWithKernel::AllOpKernels()[op_type][key] = func;
} }
template <typename PlaceType, size_t I, typename... KernelTypes> template <typename PlaceType, size_t I, typename... KernelTypes>
...@@ -96,7 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> { ...@@ -96,7 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE; using T = typename KERNEL_TYPE::ELEMENT_TYPE;
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type); RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
});
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value; constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...> OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func; func;
...@@ -150,7 +154,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -150,7 +154,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
std::tuple<DataTypeAndKernelType...>>::type; std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type) const {
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type); RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
});
constexpr auto size = constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value; std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
......
...@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx)); kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
if (!transfered_inplace_vars.empty()) { if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered. // there is inplace variable has been transfered.
......
...@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase { ...@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
using OpKernelFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>, std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;
OpKernelType::Hash>;
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) const VariableNameMap& outputs, const AttributeMap& attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册