diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 43ab227a9478707445892c14723801992d0041aa..674159b732b872ac7a6addef0bec5deb12cec4a1 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -91,7 +91,10 @@ struct OpKernelRegistrarFunctor { OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), StringToDataLayout(data_layout), StringToLibraryType(library_type)); - OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE); + OperatorWithKernel::AllOpKernels()[op_type][key] = + [](const framework::ExecutionContext& ctx) { + KERNEL_TYPE().Compute(ctx); + }; constexpr auto size = std::tuple_size>::value; OpKernelRegistrarFunctor diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 71cd5a39083471af52598cc2a1d4c591d3780624..3cf8e8696d739e3f2894e490161b9fb5b459bc41 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, dev_ctx = pool.Get(expected_kernel_key.place_); } - kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx)); + kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered. diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 1550d5df172f0599e1b42e7f1ccf51ac4dd1e0c3..01d750efbb8aaa35701f6caa7ec103ec21dd529e 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase { class OperatorWithKernel : public OperatorBase { public: + using OpKernelFunc = std::function; using OpKernelMap = - std::unordered_map, - OpKernelType::Hash>; + std::unordered_map; OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, const AttributeMap& attrs)