/**
 * \file src/mge/common.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "megdnn/dtype.h"

#include <cstddef>
#include <cstdint>
#include <vector>

using namespace lite;
using namespace mgb;

//! Tag stored in the first byte of every tensor value read by
//! decompressed_tensor_value_loader(); selects how the payload is decoded.
enum class CompressionMethod {
    NO_COMPRESSION = 0,
    FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS = 1,
    FLOAT32_STRIDE_FLOAT32_BASE_UINT16_WEIGHTS = 2,
};

namespace {

/*!
 * \brief decode (or skip) one "float32 stride, float32 base, integer weights"
 *        payload.
 *
 * Reads, in order: a float stride, a float base, then \p num_weights values of
 * type \p WeightT. Each output element is computed as `stride * w + base` and
 * stored as float into \p dst. When \p dst is null the payload bytes are
 * skipped instead of decoded.
 *
 * NOTE(review): output is always written as float, so this presumably relies
 * on the target layout being Float32 -- confirm with the compression tooling.
 *
 * \param dst          destination buffer (may be null to skip the payload)
 * \param num_weights  number of elements to decode
 * \param fin          input stream positioned just after the compression flag
 */
template <typename WeightT>
void load_stride_base_weights(
        void* dst, size_t num_weights, mgb::serialization::InputFile& fin) {
    if (!dst) {
        fin.skip(sizeof(float) * 2 + num_weights * sizeof(WeightT));
        return;
    }
    float stride = 0.f;
    float base = 0.f;
    std::vector<WeightT> weights(num_weights);
    fin.read(&stride, sizeof(stride));
    fin.read(&base, sizeof(base));
    fin.read(weights.data(), num_weights * sizeof(WeightT));
    auto* out = static_cast<float*>(dst);
    for (size_t i = 0; i < num_weights; ++i) {
        out[i] = stride * weights[i] + base;
    }
}

}  // anonymous namespace

/*!
 * \brief tensor value loader that understands the lite weight-compression
 *        formats.
 *
 * Reads a one-byte CompressionMethod flag from \p fin, then:
 *  - NO_COMPRESSION: delegates to GraphLoadConfig::default_tensor_value_loader;
 *  - *_UINT8_WEIGHTS / *_UINT16_WEIGHTS: decodes `layout.total_nr_elems()`
 *    linearly-quantized weights into \p ptr_ (or skips them if \p ptr_ is
 *    null);
 *  - any other flag: throws via LITE_THROW.
 */
void lite::decompressed_tensor_value_loader(
        void* ptr_, const mgb::TensorLayout& layout,
        mgb::serialization::InputFile& fin) {
    uint8_t compress_flag;
    fin.read(&compress_flag, sizeof(compress_flag));
    size_t num_weights = layout.total_nr_elems();
    switch (CompressionMethod(compress_flag)) {
        case CompressionMethod::NO_COMPRESSION:
            mgb::serialization::GraphLoadConfig::default_tensor_value_loader(
                    ptr_, layout, fin);
            break;
        case CompressionMethod::FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS:
            load_stride_base_weights<uint8_t>(ptr_, num_weights, fin);
            break;
        case CompressionMethod::FLOAT32_STRIDE_FLOAT32_BASE_UINT16_WEIGHTS:
            load_stride_base_weights<uint16_t>(ptr_, num_weights, fin);
            break;
        default:
            LITE_THROW("Unexpected compression method");
    }
}

/*!
 * \brief convert a lite Layout into a MegEngine TensorLayout.
 *
 * Copies ndim and shapes, initializes contiguous strides and maps the lite
 * data type to the matching mgb dtype. Throws for unsupported data types.
 */
LTensorLayout lite::to_impl_layout(const Layout& layout) {
    // A layout that uses every one of the MAX_NDIM dimensions is valid: the
    // copy loop below only touches indices [0, ndim). (Previously `<`, which
    // rejected the full-rank case.) Checked before ndim is stored.
    LITE_ASSERT(
            layout.ndim <= TensorShape::MAX_NDIM, "lite layout ndim is to large");
    mgb::TensorLayout mge_layout;
    mge_layout.ndim = layout.ndim;
    for (size_t i = 0; i < layout.ndim; i++) {
        mge_layout.shape[i] = layout.shapes[i];
    }
    mge_layout.init_contiguous_stride();
    switch (layout.data_type) {
        case LiteDataType::LITE_FLOAT:
            mge_layout.dtype = mgb::dtype::Float32();
            break;
        case LiteDataType::LITE_HALF:
            mge_layout.dtype = mgb::dtype::Float16();
            break;
        case LiteDataType::LITE_INT:
            mge_layout.dtype = mgb::dtype::Int32();
            break;
        case LiteDataType::LITE_INT8:
            mge_layout.dtype = mgb::dtype::Int8();
            break;
        case LiteDataType::LITE_UINT8:
            mge_layout.dtype = mgb::dtype::Uint8();
            break;
        case LiteDataType::LITE_INT16:
            mge_layout.dtype = mgb::dtype::Int16();
            break;
        default:
            LITE_THROW(mgb::ssprintf(
                    "unsupport dtype in lite enum id is %d.",
                    static_cast<int>(layout.data_type)));
    }
    return mge_layout;
}

/*!
 * \brief convert a MegEngine TensorLayout into a lite Layout.
 *
 * Returns a default-constructed Layout when the mgb dtype is invalid.
 * Otherwise copies ndim and shapes (strides are not carried over) and maps
 * the dtype; throws for dtypes lite does not support.
 */
Layout lite::to_lite_layout(const LTensorLayout& mge_layout) {
    Layout layout;
    if (!mge_layout.dtype.valid()) {
        return layout;
    }
    // Full-rank (ndim == MAXDIM) layouts are valid; only larger ndim would
    // index past the shapes array. (Previously `<`.)
    LITE_ASSERT(mge_layout.ndim <= Layout::MAXDIM, "tensor layout ndim is to large");
    layout.ndim = mge_layout.ndim;
    for (size_t i = 0; i < layout.ndim; i++) {
        layout.shapes[i] = mge_layout.shape[i];
    }
    switch (mge_layout.dtype.enumv()) {
        case mgb::DTypeEnum::Float32:
            layout.data_type = LiteDataType::LITE_FLOAT;
            break;
        case mgb::DTypeEnum::Float16:
            layout.data_type = LiteDataType::LITE_HALF;
            break;
        case mgb::DTypeEnum::Int32:
            layout.data_type = LiteDataType::LITE_INT;
            break;
        case mgb::DTypeEnum::Int16:
            layout.data_type = LiteDataType::LITE_INT16;
            break;
        case mgb::DTypeEnum::Int8:
            layout.data_type = LiteDataType::LITE_INT8;
            break;
        case mgb::DTypeEnum::Uint8:
            layout.data_type = LiteDataType::LITE_UINT8;
            break;
        default:
            LITE_THROW(mgb::ssprintf(
                    "unsupport dtype in lite : %s.",
                    mge_layout.to_string().c_str()));
    }
    return layout;
}

/*!
 * \brief map a lite device type to a CompNode locator.
 *
 * Only the locator's device type is set; LITE_DEVICE_DEFAULT maps to UNSPEC.
 * Throws for device types without a mapping.
 */
mgb::CompNode::Locator lite::to_compnode_locator(const LiteDeviceType& device) {
    mgb::CompNode::Locator loc;
    switch (device) {
        case LiteDeviceType::LITE_CPU:
            loc.type = mgb::CompNode::DeviceType::CPU;
            break;
        case LiteDeviceType::LITE_CUDA:
            loc.type = mgb::CompNode::DeviceType::CUDA;
            break;
        case LiteDeviceType::LITE_ATLAS:
            loc.type = mgb::CompNode::DeviceType::ATLAS;
            break;
        case LiteDeviceType::LITE_DEVICE_DEFAULT:
            loc.type = mgb::CompNode::DeviceType::UNSPEC;
            break;
        default:
            LITE_THROW(ssprintf(
                    "lite unsupported compnode type: enum value: %d.",
                    static_cast<int>(device)));
    }
    return loc;
}

/*!
 * \brief map a CompNode locator's device type back to a lite device type.
 *
 * Both CPU and MULTITHREAD locators map to LITE_CPU; UNSPEC maps to
 * LITE_DEVICE_DEFAULT. Throws for other device types.
 */
LiteDeviceType lite::get_device_from_locator(
        const mgb::CompNode::Locator& locator) {
    switch (locator.type) {
        case mgb::CompNode::DeviceType::CPU:
        case mgb::CompNode::DeviceType::MULTITHREAD:
            return LiteDeviceType::LITE_CPU;
        case mgb::CompNode::DeviceType::CUDA:
            return LiteDeviceType::LITE_CUDA;
        case mgb::CompNode::DeviceType::ATLAS:
            return LiteDeviceType::LITE_ATLAS;
        case mgb::CompNode::DeviceType::UNSPEC:
            return LiteDeviceType::LITE_DEVICE_DEFAULT;
        default:
            LITE_THROW(ssprintf(
                    "lite unsupported compnode type: enum value: %d.",
                    static_cast<int>(locator.type)));
    }
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}