/**
 * \file src/network.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "lite/network.h"
#include "function_base.h"
#include "network_impl_base.h"
#include "parse_info/parse_info_base.h"
#include "parse_model/model_parser.h"
#include "type_info.h"
#if LITE_BUILD_WITH_MGE
#include "mge/function_dft.h"
#include "mge/network_impl.h"
#endif

#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

using namespace lite;

/**
 * \brief Construct the network and its backend implementation.
 *
 * The steps must happen in this order:
 *  1. create the implementation selected by config.backend;
 *  2. pass the config to the implementation;
 *  3. pass the IO description to the implementation.
 * The model itself is loaded later through load_model().
 */
Network::Network(const Config& config, const NetworkIO& network_io) {
    LITE_ERROR_HANDLER_BEGIN
    m_config = config;
    m_network_io = network_io;
    if (config.backend == LiteBackend::LITE_DEFAULT) {
        m_impl = call_func<NetworkImplDft,
                           std::unique_ptr<lite::Network::NetworkImplBase>>(
                "create_network");
    } else if (config.backend == LiteBackend::LITE_RK_NPU) {
        m_impl = call_func<NetworkImplRK,
                           std::unique_ptr<lite::Network::NetworkImplBase>>(
                "create_network");
    }
    //! any other backend leaves m_impl empty: report it instead of
    //! dereferencing a null implementation below
    LITE_ASSERT(m_impl != nullptr, "unsupported network backend: %d.",
                static_cast<int>(config.backend));
    m_impl->set_config(config);
    m_impl->set_io(network_io);
    LITE_ERROR_HANDLER_END
}

//! Same as Network(config, network_io); only the argument order differs.
Network::Network(const NetworkIO& network_io, const Config& config)
        : Network(config, network_io) {}

/**
 * \brief Load a model from a memory buffer.
 *
 * Ownership of model_mem stays with the caller: it is wrapped with a no-op
 * deleter and is never freed here.
 */
void Network::load_model(void* model_mem, size_t size) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    LITE_CHECK_NON_NULL_POINTER(model_mem);
    //! this model_mem is managed by user
    std::shared_ptr<void> model{model_mem, [](void*) {}};
    prase_model(model, size);
    LITE_ERROR_HANDLER_END
}

/**
 * \brief Read the whole model file into memory and load it.
 *
 * The file handle is owned by a unique_ptr, so it is closed even when one of
 * the checks below throws.
 */
void Network::load_model(std::string model_path) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    std::unique_ptr<FILE, int (*)(FILE*)> fin{fopen(model_path.c_str(), "rb"),
                                              &fclose};
    LITE_ASSERT(fin != nullptr, "failed to open %s: %s", model_path.c_str(),
                strerror(errno));

    int seek_ret = fseek(fin.get(), 0, SEEK_END);
    LITE_ASSERT(seek_ret == 0, "failed to seek %s", model_path.c_str());
    //! ftell reports errors as -1, which must not be turned into a size_t
    long file_size = ftell(fin.get());
    LITE_ASSERT(file_size >= 0, "failed to get the size of %s",
                model_path.c_str());
    seek_ret = fseek(fin.get(), 0, SEEK_SET);
    LITE_ASSERT(seek_ret == 0, "failed to seek %s", model_path.c_str());

    size_t size = static_cast<size_t>(file_size);
    LITE_ASSERT(size > 0, "model file %s is empty", model_path.c_str());
    std::shared_ptr<void> buf{malloc(size), ::free};
    LITE_ASSERT(buf != nullptr, "failed to allocate memory for model %s",
                model_path.c_str());

    auto nr = fread(buf.get(), 1, size, fin.get());
    LITE_ASSERT(nr == size, "failed to read model file %s", model_path.c_str());
    //! release the file handle before the (possibly long) model parsing
    fin.reset();
    prase_model(buf, size);
    LITE_ERROR_HANDLER_END
}

/**
 * \brief Parse a model buffer and load it into the implementation.
 *
 * When the buffer carries model info, parse_model_info may update m_config,
 * m_network_io and m_extra_info. If the configured backend then differs from
 * the current implementation's backend, the implementation is re-created.
 * The config and IO are re-applied afterwards.
 */
void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
    std::unordered_map<std::string, LiteAny> separate_config_map;
    ModelParser model_parser(model_data, size);
    //! parse the model info
    if (model_parser.parse_model_info(m_config, m_network_io,
                                      separate_config_map, m_extra_info)) {
        if (m_config.backend == LiteBackend::LITE_DEFAULT &&
            m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) {
            m_impl.reset(try_call_func<NetworkImplDft,
                                       lite::Network::NetworkImplBase*>(
                    "parse_model"));
        } else if (m_config.backend == LiteBackend::LITE_RK_NPU &&
                   m_impl->get_backend_type() != LiteBackend::LITE_RK_NPU) {
            m_impl.reset(try_call_func<NetworkImplRK,
                                       lite::Network::NetworkImplBase*>(
                    "parse_model"));
        }
        //! try_call_func may hand back a null implementation (presumably
        //! when the requested backend is not built in); fail clearly
        //! instead of dereferencing it
        LITE_ASSERT(m_impl != nullptr,
                    "failed to create the network implementation for the "
                    "backend required by the model.");
        m_impl->set_config(m_config);
        m_impl->set_io(m_network_io);
    }
    //! decrypt the model and get the buffer handed to the implementation
    size_t model_length;
    auto&& model_shared_ptr = model_parser.parse_model(model_length, m_config);

    m_impl->load_model(model_shared_ptr, model_length, separate_config_map);
    m_loaded = true;
    update_from_implement();
}

Network::~Network() = default;

//! Sync the cached config with the implementation's actual device type.
void Network::update_from_implement() {
    m_config.device_type = m_impl->get_device_type();
}

void Network::compute_only_configured_output() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(!m_loaded,
                "compute_only_configured_output should be used before model "
                "loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->compute_only_configured_output();
    LITE_ERROR_HANDLER_END
}

std::shared_ptr<Tensor> Network::get_io_tensor(std::string name,
                                               LiteTensorPhase phase) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_io_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_io_tensor(name, phase);
    LITE_ERROR_HANDLER_END
}

std::shared_ptr<Tensor> Network::get_input_tensor(size_t index) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_input_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_input_tensor(index);
    LITE_ERROR_HANDLER_END
}

std::shared_ptr<Tensor> Network::get_output_tensor(size_t index) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_output_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_output_tensor(index);
    LITE_ERROR_HANDLER_END
}

//! The callback is a const reference, so it is copied into the
//! implementation (std::move on a const object would copy anyway).
Network& Network::set_async_callback(const AsyncCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_async_callback(callback);
    return *this;
    LITE_ERROR_HANDLER_END
}

Network& Network::set_start_callback(const StartCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_start_callback(callback);
    return *this;
    LITE_ERROR_HANDLER_END
}

Network& Network::set_finish_callback(const FinishCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_finish_callback(callback);
    return *this;
    LITE_ERROR_HANDLER_END
}

Network& Network::set_device_id(int device_id) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(!m_loaded, "set_device_id should be used before model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_device_id(device_id);
    return *this;
    LITE_ERROR_HANDLER_END
}

Network& Network::set_stream_id(int stream_id) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(!m_loaded, "set_stream_id should be used before model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_stream_id(stream_id);
    return *this;
    LITE_ERROR_HANDLER_END
}

void Network::forward() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "forward should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->forward();
    LITE_ERROR_HANDLER_END
}

void Network::wait() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "wait should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->wait();
    LITE_ERROR_HANDLER_END
}

std::string Network::get_input_name(size_t index) const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_input_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_input_name(index);
    LITE_ERROR_HANDLER_END
}

std::string Network::get_output_name(size_t index) const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_output_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_output_name(index);
    LITE_ERROR_HANDLER_END
}

//! Copy the implementation's input names into owned std::string values.
std::vector<std::string> Network::get_all_input_name() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_all_input_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto all_input_name = m_impl->get_all_input_name();
    return std::vector<std::string>(all_input_name.begin(),
                                    all_input_name.end());
    LITE_ERROR_HANDLER_END
}

//! Copy the implementation's output names into owned std::string values.
std::vector<std::string> Network::get_all_output_name() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_all_output_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto all_output_name = m_impl->get_all_output_name();
    return std::vector<std::string>(all_output_name.begin(),
                                    all_output_name.end());
    LITE_ERROR_HANDLER_END
}

int Network::get_device_id() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_device_id();
    LITE_ERROR_HANDLER_END
}

int Network::get_stream_id() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_stream_id();
    LITE_ERROR_HANDLER_END
}

void Network::enable_profile_performance(std::string profile_file_path) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->enable_profile_performance(profile_file_path);
    LITE_ERROR_HANDLER_END
}

//! Extra info filled in by the model parser while loading.
const std::string& Network::get_model_extra_info() {
    LITE_ERROR_HANDLER_BEGIN
    return m_extra_info;
    LITE_ERROR_HANDLER_END
}

LiteDeviceType Network::get_device_type() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_device_type();
    LITE_ERROR_HANDLER_END
}

/*********************** MGE special network function ***************/

void Runtime::set_cpu_threads_number(std::shared_ptr<Network> network,
                                     size_t nr_threads) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                !NetworkHelper::loaded(network),
                "set_cpu_threads_number should be used before model loaded.");
        call_func<NetworkImplDft, void>("set_cpu_threads_number", network_impl,
                                        nr_threads);
        return;
    }
    LITE_THROW("set_cpu_threads_number is not aviliable in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::use_tensorrt(std::shared_ptr<Network> network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(!NetworkHelper::loaded(network),
                    "use_tensorrt should be used before model loaded.");
        call_func<NetworkImplDft, void>("use_tensorrt", network_impl);
        return;
    }
    LITE_THROW("use_tensorrt is not aviliable in the backend.");
    LITE_ERROR_HANDLER_END
}

size_t Runtime::get_cpu_threads_number(const std::shared_ptr<Network> network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        return call_func<NetworkImplDft, size_t>("get_cpu_threads_number",
                                                 network_impl);
    }
    LITE_THROW("get_cpu_threads_number is not aviliable in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::set_runtime_thread_affinity(
        std::shared_ptr<Network> network,
        const ThreadAffinityCallback& thread_affinity_callback) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(NetworkHelper::loaded(network),
                    "set_runtime_thread_affinity should be used after model "
                    "loaded.");
        call_func<NetworkImplDft, void>("set_runtime_thread_affinity",
                                        network_impl, thread_affinity_callback);
        return;
    }
    LITE_THROW("set_runtime_thread_affinity is not aviliable in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::set_cpu_inplace_mode(std::shared_ptr<Network> network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(!NetworkHelper::loaded(network),
                    "set_cpu_inplace_mode should be used before model loaded.");
        call_func<NetworkImplDft, void>("set_cpu_inplace_mode", network_impl);
        return;
    }
    LITE_THROW("set_cpu_inplace_mode is not aviliable in the backend.");
    LITE_ERROR_HANDLER_END
}

bool Runtime::is_cpu_inplace_mode(const std::shared_ptr<Network> network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        return call_func<NetworkImplDft, bool>("is_cpu_inplace_mode",
                                               network_impl);
    }
    LITE_THROW("is_cpu_inplace_mode is not aviliable in the backend.");
    LITE_ERROR_HANDLER_END
}
set opr algorithm selection strategy in the network void Runtime::set_network_algo_policy(std::shared_ptr network, LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size, bool binary_equal_between_batch) { LITE_ERROR_HANDLER_BEGIN auto network_impl = NetworkHelper::implement(network); if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) { call_func("set_network_algo_policy", network_impl, strategy, shared_batch_size, binary_equal_between_batch); return; } LITE_THROW("set_network_algo_policy is not aviliable in the backend."); LITE_ERROR_HANDLER_END } //! set opr algorithm selection strategy in the network void Runtime::set_network_algo_workspace_limit(std::shared_ptr network, size_t workspace_limit) { LITE_ERROR_HANDLER_BEGIN auto network_impl = NetworkHelper::implement(network); if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) { LITE_ASSERT(NetworkHelper::loaded(network), "set_network_algo_policy should be used after model " "loaded."); call_func("set_network_algo_workspace_limit", network_impl, workspace_limit); return; } LITE_THROW( "set_network_algo_workspace_limit is not aviliable in the " "backend."); LITE_ERROR_HANDLER_END } //! 
set the network memroy allocator, the allocator is defined by user void Runtime::set_memory_allocator(std::shared_ptr network, std::shared_ptr user_allocator) { LITE_ERROR_HANDLER_BEGIN auto network_impl = NetworkHelper::implement(network); if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) { LITE_ASSERT(!NetworkHelper::loaded(network), "set_memory_allocator should be used before model loaded."); call_func("set_memory_allocator", network_impl, user_allocator); return; } LITE_THROW("set_memory_allocator is not aviliable in the backend."); LITE_ERROR_HANDLER_END } void Runtime::share_runtime_memory_with(std::shared_ptr dst_network, std::shared_ptr src_network) { LITE_ERROR_HANDLER_BEGIN auto network_impl_dst = NetworkHelper::implement(dst_network); if (network_impl_dst->get_backend_type() == LiteBackend::LITE_DEFAULT) { LITE_ASSERT(!NetworkHelper::loaded(dst_network), "share_runtime_memory_with should be used before model " "loaded."); call_func("share_runtime_memory_with", network_impl_dst, NetworkHelper::implement(src_network)); return; } LITE_THROW("share_runtime_memory_with is not aviliable in the backend."); LITE_ERROR_HANDLER_END } void Runtime::enable_io_txt_dump(std::shared_ptr network, std::string io_txt_out_file) { LITE_ERROR_HANDLER_BEGIN auto network_impl = NetworkHelper::implement(network); if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) { call_func("enable_io_txt_dump", network_impl, io_txt_out_file); return; } LITE_THROW("enable_io_txt_dump is not aviliable in the backend."); LITE_ERROR_HANDLER_END } void Runtime::enable_io_bin_dump(std::shared_ptr network, std::string io_bin_out_dir) { LITE_ERROR_HANDLER_BEGIN auto network_impl = NetworkHelper::implement(network); if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) { call_func("enable_io_bin_dump", network_impl, io_bin_out_dir); return; } LITE_THROW("enable_io_bin_dump is not aviliable in the backend."); LITE_ERROR_HANDLER_END } void 
Runtime::shared_weight_with_network( std::shared_ptr dst_network, const std::shared_ptr src_network) { LITE_ERROR_HANDLER_BEGIN auto network_impl_dst = NetworkHelper::implement(dst_network); if (network_impl_dst->get_backend_type() == LiteBackend::LITE_DEFAULT) { LITE_ASSERT(NetworkHelper::loaded(src_network), "shared_weight_with_network should be used after the src " "network " "loaded."); auto src_implment = NetworkHelper::implement(src_network); call_func("shared_weight_with", network_impl_dst, src_implment); NetworkHelper::loaded(dst_network, true); return; } LITE_THROW("shared_weight_with_network is not aviliable in the backend."); LITE_ERROR_HANDLER_END } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}