未验证 提交 75ae426a 编写于 作者: Y yuyang18

Merge branch 'feature/change_op_kernel_to_func' into feature/fix_reshape_op_size

......@@ -76,8 +76,9 @@ class OpRegistry {
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;
template <typename PlaceType, typename T, typename KernelType>
inline void RegisterKernelClass(const char* op_type, const char* library_type) {
template <typename PlaceType, typename T, typename Func>
inline void RegisterKernelClass(const char* op_type, const char* library_type,
Func func) {
std::string library(library_type);
std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") {
......@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type) {
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
StringToDataLayout(data_layout),
StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType());
OperatorWithKernel::AllOpKernels()[op_type][key] = func;
}
template <typename PlaceType, size_t I, typename... KernelTypes>
......@@ -96,7 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
});
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func;
......@@ -150,7 +154,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const {
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
});
constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
......
......@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx));
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered.
......
......@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
class OperatorWithKernel : public OperatorBase {
public:
using OpKernelFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap =
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
OpKernelType::Hash>;
std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册