// Copyright 2020 The MACE Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MACE_CORE_REGISTRY_OPS_REGISTRY_H_ #define MACE_CORE_REGISTRY_OPS_REGISTRY_H_ #include #include #include #include #include #include "mace/core/bfloat16.h" #include "mace/core/fp16.h" #include "mace/core/types.h" #include "mace/core/ops/operator.h" #include "mace/core/ops/op_condition_builder.h" #include "mace/core/ops/op_condition_context.h" #include "mace/public/mace.h" #include "mace/proto/mace.pb.h" #include "mace/utils/memory.h" namespace mace { class OpRegistry { public: OpRegistry() = default; virtual ~OpRegistry() = default; MaceStatus Register(const std::string &op_type, const DeviceType device_type, const DataType dt, OpRegistrationInfo::OpCreator creator); MaceStatus Register(const OpConditionBuilder &builder); const std::set AvailableDevices( const std::string &op_type, OpConditionContext *context) const; void GetInOutMemoryTypes( const std::string &op_type, OpConditionContext *context) const; const std::vector InputsDataFormat( const std::string &op_type, OpConditionContext *context) const; std::unique_ptr CreateOperation( OpConstructContext *context, DeviceType device_type) const; template static std::unique_ptr DefaultCreator( OpConstructContext *context) { return make_unique(context); } private: std::unordered_map> registry_; MACE_DISABLE_COPY_AND_ASSIGN(OpRegistry); }; #define MACE_REGISTER_OP(op_registry, op_type, class_name, device, dt) \ op_registry->Register(op_type, \ device, \ DataTypeToEnum
::value, \ OpRegistry::DefaultCreator>) #define MACE_REGISTER_OP_BY_CLASS(\ op_registry, op_type, class_name, device, dt) \ op_registry->Register(op_type, \ device, \ DataTypeToEnum
::value, \ OpRegistry::DefaultCreator) #ifndef MACE_REGISTER_BF16_OP #ifdef MACE_ENABLE_BFLOAT16 #define MACE_REGISTER_BF16_OP(op_registry, op_type, class_name, device) \ MACE_REGISTER_OP(op_registry, op_type, class_name, device, BFloat16) #else #define MACE_REGISTER_BF16_OP(op_registry, op_type, class_name, device) #endif // MACE_ENABLE_BFLOAT16 #endif // MACE_REGISTER_BF16_OP #ifndef MACE_REGISTER_BF16_OP_BY_CLASS #ifdef MACE_ENABLE_BFLOAT16 #define MACE_REGISTER_BF16_OP_BY_CLASS(op_registry, op_type, \ class_name, device) \ MACE_REGISTER_OP_BY_CLASS(op_registry, op_type, \ class_name, device, BFloat16) #else #define MACE_REGISTER_BF16_OP_BY_CLASS(op_registry, op_type, class_name, device) #endif // MACE_ENABLE_BFLOAT16 #endif // MACE_REGISTER_BF16_OP_BY_CLASS #ifndef MACE_REGISTER_FP16_OP #ifdef MACE_ENABLE_FP16 #define MACE_REGISTER_FP16_OP(op_registry, op_type, class_name, device) \ MACE_REGISTER_OP(op_registry, op_type, class_name, device, float16_t) #else #define MACE_REGISTER_FP16_OP(op_registry, op_type, class_name, device) #endif // MACE_ENABLE_FP16 #endif // MACE_REGISTER_FP16_OP #ifndef MACE_REGISTER_FP16_OP_BY_CLASS #ifdef MACE_ENABLE_FP16 #define MACE_REGISTER_FP16_OP_BY_CLASS(op_registry, op_type, \ class_name, device) \ MACE_REGISTER_OP_BY_CLASS(op_registry, op_type, \ class_name, device, float16_t) #else #define MACE_REGISTER_FP16_OP_BY_CLASS(op_registry, op_type, \ class_name, device) #endif // MACE_ENABLE_FP16 #endif // MACE_REGISTER_FP16_OP_BY_CLASS #ifdef MACE_ENABLE_OPENCL #define MACE_REGISTER_GPU_OP(op_registry, op_type, class_name) \ op_registry->Register( \ op_type, \ DeviceType::GPU, \ DT_FLOAT, \ OpRegistry::DefaultCreator>) #else #define MACE_REGISTER_GPU_OP(op_registry, op_type, class_name) #endif #define MACE_REGISTER_OP_CONDITION(op_registry, builder) \ op_registry->Register(builder) } // namespace mace #endif // MACE_CORE_REGISTRY_OPS_REGISTRY_H_