task_factory.h 3.1 KB
Newer Older
L
lujiale 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_GE_RUNTIME_TASK_TASK_FACTORY_H_
#define GE_GE_RUNTIME_TASK_TASK_FACTORY_H_

#include <functional>
#include <map>
#include <memory>
#include <unordered_map>

#include "common/ge_inner_error_codes.h"
#include "framework/common/debug/ge_log.h"
#include "ge_runtime/task_info.h"

namespace ge {
namespace model_runner {
class Task;
class ModelContext;
using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>;

class TaskFactory {
 private:
  TaskFactory() {}
  ~TaskFactory() {}
  void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
    if (creator_map_.find(type) != creator_map_.end()) {
      GELOGW("Creator type %d already exist", static_cast<int32_t>(type));
    }
    creator_map_[type] = func;
  }

  std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_;

 public:
  static TaskFactory &GetInstance() {
    static TaskFactory instance;
    return instance;
  }

  std::shared_ptr<Task> Create(const ModelContext &model_context, std::shared_ptr<TaskInfo> &task_info) const {
    if (task_info == nullptr) {
      GELOGE(FAILED, "task_info is null.");
      return nullptr;
    }

    auto iter = creator_map_.find(task_info->type());
    if (iter == creator_map_.end()) {
      GELOGE(FAILED, "Unknow task type %d", static_cast<int32_t>(task_info->type()));
      return nullptr;
    }
    return iter->second(model_context, task_info);
  }

  class Register {
   public:
    Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
      GELOGI("regist type %d", static_cast<int32_t>(type));
      TaskFactory::GetInstance().RegisterCreator(type, func);
    }

    ~Register() {}
  };
};

#define REGISTER_TASK(type, task_clazz, task_info_clazz)                                                            \
  TaskFactory::Register g_##task_clazz##_register(                                                                  \
      type,                                                                                                         \
      [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> {  \
        std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \
        return std::make_shared<task_clazz>(model_context, concrete_task_info);                                     \
      });

}  // namespace model_runner
}  // namespace ge
#endif  // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_