/**
 * \file src/mge/function_dft.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#if LITE_BUILD_WITH_MGE

#include "function_base.h"
#include "network_impl.h"
#include "network_impl_base.h"
#include "tensor_impl.h"

namespace lite {

#define THROW_FUNC_ERROR(func_name)                                      \
    auto msg_info = func_name + " is not available in the Dft backend."; \
    LITE_THROW(msg_info.c_str())

// the functions used by dft's tensor.cpp are as follows:

template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>();
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, LiteDeviceType device_type, bool is_pinned_host) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>(device_type, is_pinned_host);
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, int device_id, LiteDeviceType device_type,
        const Layout layout, bool is_pinned_host) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>(
                device_id, device_type, layout, is_pinned_host);
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, LiteDeviceType device_type, const Layout layout,
        bool is_pinned_host) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>(device_type, layout, is_pinned_host);
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, int device_id, int stream_id,
        LiteDeviceType device_type, bool is_pinned_host) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>(
                device_id, stream_id, device_type, is_pinned_host);
    }
    THROW_FUNC_ERROR(func_name);
}

// the functions used by dft's network.cpp are as follows:

template <>
inline std::unique_ptr<Network::NetworkImplBase> call_func<
        NetworkImplDft, std::unique_ptr<Network::NetworkImplBase>>(
        std::string func_name) {
    if (func_name == "create_network") {
        return std::make_unique<NetworkImplDft>();
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline Network::NetworkImplBase* try_call_func<
        NetworkImplDft, Network::NetworkImplBase*>(std::string func_name) {
    if (func_name == "parse_model") {
        return new NetworkImplDft();
    }
    THROW_FUNC_ERROR(func_name);
}
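// Usage sketch (illustrative only; the real call sites live in tensor.cpp and
// network.cpp, and the concrete arguments below are assumptions): the frontend
// selects a backend entry point by name and forwards its arguments through the
// matching specialization, e.g.
//
//     // create a backend tensor for the host CPU:
//     auto impl = call_func<
//             TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
//             "create_tensor", LiteDeviceType::LITE_CPU, false);
//
// Any name a specialization does not recognize falls through to
// THROW_FUNC_ERROR, so a typo in the dispatch string is a runtime lite
// exception rather than a compile-time error.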
#define CALL_FUNC(func_name, ...) \
    network_impl->cast_final_safe<NetworkImplDft>().func_name(__VA_ARGS__)

template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl, size_t num) {
    if (func_name == "set_cpu_threads_number") {
        CALL_FUNC(set_cpu_threads_number, num);
    } else if (func_name == "set_network_algo_workspace_limit") {
        CALL_FUNC(set_network_algo_workspace_limit, num);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "use_tensorrt") {
        CALL_FUNC(use_tensorrt);
    } else if (func_name == "set_cpu_inplace_mode") {
        CALL_FUNC(set_cpu_inplace_mode);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

template <>
inline size_t call_func<NetworkImplDft, size_t>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "get_cpu_threads_number") {
        return CALL_FUNC(get_cpu_threads_number);
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline bool call_func<NetworkImplDft, bool>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "is_cpu_inplace_mode") {
        return CALL_FUNC(is_cpu_inplace_mode);
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        ThreadAffinityCallback thread_affinity_callback) {
    if (func_name == "set_runtime_thread_affinity") {
        return CALL_FUNC(
                set_runtime_thread_affinity, std::move(thread_affinity_callback));
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    if (func_name == "set_network_algo_policy") {
        return CALL_FUNC(
                set_network_algo_policy, strategy, shared_batch_size,
                binary_equal_between_batch);
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        std::shared_ptr<Allocator> user_allocator) {
    if (func_name == "set_memory_allocator") {
        return CALL_FUNC(set_memory_allocator, user_allocator);
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        std::string file_name) {
    if (func_name == "enable_io_txt_dump") {
        return CALL_FUNC(enable_io_txt_dump, file_name);
    } else if (func_name == "enable_io_bin_dump") {
        return CALL_FUNC(enable_io_bin_dump, file_name);
    }
    THROW_FUNC_ERROR(func_name);
}

template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        Network::NetworkImplBase* src_network_impl) {
    if (func_name == "share_runtime_memory_with") {
        CALL_FUNC(share_runtime_memory_with, src_network_impl);
    } else if (func_name == "shared_weight_with") {
        CALL_FUNC(shared_weight_with, src_network_impl);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

#undef CALL_FUNC
#undef THROW_FUNC_ERROR

}  // namespace lite

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
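// Note (editorial sketch, not in the original header): because the Args pack
// of call_func is deduced at the call site, callers must pass argument types
// that match a specialization exactly; passing a plain int literal where a
// specialization expects size_t deduces a different signature and ends up in
// the primary template instead. A hypothetical call from network.cpp:
//
//     call_func<NetworkImplDft, void>(
//             "set_cpu_threads_number", network_impl, static_cast<size_t>(4));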