From 3b9b87809db0ddc849ce0b629b142434e7de1c32 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 24 Mar 2021 14:06:27 +0800 Subject: [PATCH] refactor(dnn): refactor lowbit tensor format GitOrigin-RevId: b646dc085b171589d5dec24650ccf7148f8a03af --- dnn/include/megdnn/basic_types.h | 4 + dnn/include/megdnn/tensor_format.h | 39 +-- dnn/src/common/basic_types.cpp | 13 +- dnn/src/common/convolution.cpp | 3 +- dnn/src/common/tensor_format.cpp | 101 +++--- .../conv_bias/cudnn_conv_bias_activation.cpp | 3 + dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp | 23 +- dnn/src/cuda/conv_bias/helper.cpp | 4 + .../sass_implicit_gemm_int4_nchw64_imma.cpp | 11 +- dnn/test/common/test_basic_types.cpp | 8 +- src/core/impl/graph/operator_node.cpp | 13 +- src/core/impl/graph/var_node_mem_mgr.cpp | 25 +- src/core/impl/tensor.cpp | 29 +- src/core/impl/utils/persistent_cache.cpp | 3 +- src/core/include/megbrain/tensor.h | 5 +- src/gopt/impl/inference.cpp | 6 +- src/opr/impl/basic_arith.cpp | 9 +- src/opr/impl/io.cpp | 7 +- src/opr/impl/search_policy/algo_chooser.cpp | 12 +- src/opr/impl/search_policy/profiler.cpp | 4 +- src/opr/test/dnn/convolution.cpp | 302 ++++++++++++++++++ 21 files changed, 497 insertions(+), 127 deletions(-) diff --git a/dnn/include/megdnn/basic_types.h b/dnn/include/megdnn/basic_types.h index c2902613..a01cea1c 100644 --- a/dnn/include/megdnn/basic_types.h +++ b/dnn/include/megdnn/basic_types.h @@ -170,6 +170,7 @@ struct TensorLayout : public TensorShape { #if MEGDNN_CC_HOST Format(); + Format(DType dtype); const ImplBase* impl() const { return m_impl; } @@ -198,6 +199,9 @@ struct TensorLayout : public TensorShape { //! whether this is the default tensor format bool is_default() const; + //! whether this is the lowbit aligned to bytes tensor format + bool is_lowbit_aligned() const; + bool operator==(Format rhs) const { return m_impl == rhs.m_impl; } bool operator!=(Format rhs) const { return m_impl != rhs.m_impl; } #endif diff --git a/dnn/include/megdnn/tensor_format.h b/dnn/include/megdnn/tensor_format.h index f8dd3746..4a0bb570 100644 --- a/dnn/include/megdnn/tensor_format.h +++ b/dnn/include/megdnn/tensor_format.h @@ -20,7 +20,7 @@ namespace megdnn { enum class TensorFormat::Type { DEFAULT = 0, //!< see DefaultTensorFormat IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat - FOURBITS_ALIGNED_TO_BYTE = 2, //!< + LOWBITS_ALIGNED_TO_BYTE = 2, //!< }; class TensorFormat::ImplBase { @@ -205,21 +205,23 @@ using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; /*! * \brief used for tensors storing lowbit data * - * \p SIZE_NBITS is the size in bits of element of the tensor. - * + * \param m_size_nbits size in bits of elements in the tensor + * \param m_align_size_in_bits aligned size in bits + * \param m_align_size_in_elements aligned size in elements */ -template -class LowbitsTensorFormatBase : public TensorFormat::ImplBase { - static constexpr size_t SIZE_NBITS = SIZE_NBITS_; - size_t m_align_size_in_bits, m_align_size_in_elements; +class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase { + size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; protected: //? 
- LowbitsTensorFormatBase(Type type, size_t align_size_in_bits); + LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits, + size_t align_size_in_bits); - virtual ~LowbitsTensorFormatBase() = default; + virtual ~LowbitsAlignedTensorFormatBase() = default; public: size_t align_size_in_bits() const { return m_align_size_in_bits; } + + size_t size_nbits() const { return m_size_nbits; } std::string to_string() const override; @@ -240,10 +242,10 @@ public: const TensorLayout& layout) const override; protected: struct SerializePack { + uint8_t size_nbits; uint8_t align_size_in_bits; }; }; -using FourBitsAlignedToBytesTensorFormatBase = LowbitsTensorFormatBase<4>; } // namespace detail /*! @@ -296,19 +298,20 @@ private: * \brief Tensor for storing 4bit data that requires stride corresponding to * non-innermost dimension to be aligned to bytes, and pack 2 elems into a byte */ -class FourBitsAlignedToBytesTensorFormat final - : public detail::FourBitsAlignedToBytesTensorFormatBase { +class LowbitsAlignedToBytesTensorFormat final + : public detail::LowbitsAlignedTensorFormatBase { public: - static constexpr Type TYPE = Type::FOURBITS_ALIGNED_TO_BYTE; + static constexpr Type TYPE = Type::LOWBITS_ALIGNED_TO_BYTE; + static constexpr size_t BYTE_IN_BITS = 8; - static TensorFormat make(size_t align_size_in_bits); + static TensorFormat make(size_t size_nbits); static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); static bool is_valid_layout(const TensorLayout& layout) { if (layout.format.type() == TYPE) { - layout.format.as_impl() + layout.format.as_impl() .assert_valid(layout); return true; } @@ -316,9 +319,9 @@ public: } private: - FourBitsAlignedToBytesTensorFormat(size_t align_size_in_bits) - : detail::FourBitsAlignedToBytesTensorFormatBase( - TYPE, align_size_in_bits) {} + LowbitsAlignedToBytesTensorFormat(size_t size_nbits) + : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, + BYTE_IN_BITS) {} }; } // namespace megdnn diff --git a/dnn/src/common/basic_types.cpp b/dnn/src/common/basic_types.cpp index 78700e89..d2e96c93 100644 --- a/dnn/src/common/basic_types.cpp +++ b/dnn/src/common/basic_types.cpp @@ -195,21 +195,14 @@ bool TensorShape::is_empty() const { /* ===================== TensorLayout ===================== */ TensorLayout::TensorLayout() = default; -TensorLayout::TensorLayout(DType dtype_) : dtype{dtype_} {} +TensorLayout::TensorLayout(DType dtype_) + : dtype{dtype_}, format{Format(dtype)} {} TensorLayout::TensorLayout(DType dtype_, Format format_) : dtype{dtype_}, format{format_} {} TensorLayout::TensorLayout(const TensorShape& shape, DType dtype) - : TensorShape(shape), dtype{dtype} { - if (dtype.low_bit() == 4_z) { - format = FourBitsAlignedToBytesTensorFormat::make(8_z); - } else { - megdnn_assert(!dtype.is_low_bit(), "Unsupported data type(%s)", - dtype.name()); - format = DefaultTensorFormat::make(); - } -} + : TensorLayout(shape, dtype, Format(dtype)) {} TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, TensorFormat format_) diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp index b645cfbf..d7d3532b 100644 --- a/dnn/src/common/convolution.cpp +++ b/dnn/src/common/convolution.cpp @@ -722,7 +722,7 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, megdnn_assert(src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && src[src.ndim - 1] == 64 && - filter[filter.ndim - 1] == 4, + filter[filter.ndim - 1] == 64, "NCHW64 require src and filter's ndim is 5 or 6, and " "last shape is 64 but got 
src %s, filter %s", src.to_string().c_str(), filter.to_string().c_str()); @@ -754,7 +754,6 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i], cflt.stride[i], cflt.padding[i]); } - dst.init_contiguous_stride(); } else if (param().format == Param::Format::NCHW4) { megdnn_assert(src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu", diff --git a/dnn/src/common/tensor_format.cpp b/dnn/src/common/tensor_format.cpp index 33a19b74..c65e5cf8 100644 --- a/dnn/src/common/tensor_format.cpp +++ b/dnn/src/common/tensor_format.cpp @@ -35,8 +35,8 @@ TensorFormat TensorFormat::deserialize(const std::string& bin, case Type::IMAGE2D_PACK4: return Image2DPack4TensorFormat::deserialize( handle, type + 1, bin.size() - sizeof(Type)); - case Type::FOURBITS_ALIGNED_TO_BYTE: - return FourBitsAlignedToBytesTensorFormat::deserialize( + case Type::LOWBITS_ALIGNED_TO_BYTE: + return LowbitsAlignedToBytesTensorFormat::deserialize( handle, type + 1, bin.size() - sizeof(Type)); default: megdnn_throw("invalid tensor format type in deserialize"); @@ -45,6 +45,19 @@ TensorFormat TensorFormat::deserialize(const std::string& bin, TensorFormat::Format() : m_impl{DefaultTensorFormat::make().m_impl} {} +TensorFormat::Format(DType dtype) { + megdnn_assert(dtype.valid()); + if (dtype.is_low_bit()) { + size_t size_nbits = dtype.low_bit(); + megdnn_assert(size_nbits == 1 || size_nbits == 2 || size_nbits == 4, + "unsupported lowbits data type(%s, size in bits: %zu)", + dtype.name(), size_nbits); + m_impl = LowbitsAlignedToBytesTensorFormat::make(size_nbits).m_impl; + } else { + m_impl = DefaultTensorFormat::make().m_impl; + } +} + std::string TensorFormat::to_string() const { return m_impl->to_string(); } @@ -69,6 +82,10 @@ bool TensorFormat::is_default() const { return m_impl == default_tensor_format_obj; } +bool TensorFormat::is_lowbit_aligned() const { + return type() == TensorFormat::Type::LOWBITS_ALIGNED_TO_BYTE; +} + /* ===================== DefaultFormat ===================== */ void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const { megdnn_assert( @@ -440,27 +457,26 @@ template class Image2DPackedTensorFormatBase<4>; } // namespace detail } // namespace megdnn -/* =============== FourBitsAlignedToBytesTensorFormatBase ============== */ -template -LowbitsTensorFormatBase::LowbitsTensorFormatBase( - Type type, size_t align_size_in_bits) - : ImplBase(type), m_align_size_in_bits(align_size_in_bits) { - megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS), +/* =============== LowbitsAlignedTensorFormatBase ============== */ +LowbitsAlignedTensorFormatBase::LowbitsAlignedTensorFormatBase( + Type type, size_t size_nbits, size_t align_size_in_bits) + : ImplBase(type), + m_size_nbits(size_nbits), + m_align_size_in_bits(align_size_in_bits) { + megdnn_assert(!(m_align_size_in_bits % m_size_nbits), "align size(%zu) must be a multiple of element size(%zu)", - m_align_size_in_bits, SIZE_NBITS); - m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS; + m_align_size_in_bits, m_size_nbits); + m_align_size_in_elements = m_align_size_in_bits / m_size_nbits; } -template -std::string LowbitsTensorFormatBase::to_string() const { - return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits); +std::string LowbitsAlignedTensorFormatBase::to_string() const { + return ssprintf("LOWBITS{%zu,%zu}", m_size_nbits, m_align_size_in_bits); } -template -void LowbitsTensorFormatBase::assert_valid( +void 
LowbitsAlignedTensorFormatBase::assert_valid( const TensorLayout& layout) const { megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() && - layout.dtype.low_bit() == SIZE_NBITS); + layout.dtype.low_bit() == m_size_nbits); bool has_dim_unity_stride = false; for (int i = layout.ndim - 1; i >= 0; --i) { if (!has_dim_unity_stride && layout.stride[i] == 1) @@ -469,23 +485,28 @@ void LowbitsTensorFormatBase::assert_valid( layout.stride[i] >= 0 && (layout.stride[i] % m_align_size_in_elements == 0 || layout.stride[i] == 1), - "bad stride: %zu", layout.stride[i]); + "bad stride:%s, %zu", layout.to_string().c_str(), + layout.stride[i]); } - megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous"); + /// FIXME + if (layout.ndim == 0) { + printf("%s\n", layout.to_string().c_str()); + } + megdnn_assert(layout.ndim == 0 || has_dim_unity_stride, + "innermost dim not contiguous"); } -template -void LowbitsTensorFormatBase::serialize_append( +void LowbitsAlignedTensorFormatBase::serialize_append( std::string& result) const { SerializePack pack; + pack.size_nbits = m_size_nbits; pack.align_size_in_bits = m_align_size_in_bits; megdnn_assert(pack.align_size_in_bits == m_align_size_in_bits); // detect overflow; result.append(reinterpret_cast(&pack), sizeof(pack)); } -template -TensorLayout::Span LowbitsTensorFormatBase::span_spec( +TensorLayout::Span LowbitsAlignedTensorFormatBase::span_spec( const TensorLayout& layout) const { assert_valid(layout); if (layout.ndim == 0) @@ -507,8 +528,7 @@ TensorLayout::Span LowbitsTensorFormatBase::span_spec( return TensorLayout::Span(0, 0, high_elem, high_byte); } -template -size_t LowbitsTensorFormatBase::init_contiguous_stride( +size_t LowbitsAlignedTensorFormatBase::init_contiguous_stride( TensorLayout& layout) const { if (!layout.ndim) return 0; @@ -525,8 +545,7 @@ size_t LowbitsTensorFormatBase::init_contiguous_stride( return accum; } -template -bool LowbitsTensorFormatBase::is_contiguous_spec( +bool LowbitsAlignedTensorFormatBase::is_contiguous_spec( const TensorLayout& layout) const { assert_valid(layout); ptrdiff_t expected = 1; @@ -541,8 +560,7 @@ bool LowbitsTensorFormatBase::is_contiguous_spec( return expected != 0; } -template -TensorLayout LowbitsTensorFormatBase::collapse_contiguous_spec( +TensorLayout LowbitsAlignedTensorFormatBase::collapse_contiguous_spec( const TensorLayout& layout) const { assert_valid(layout); TensorLayout res{layout}; @@ -572,12 +590,6 @@ TensorLayout LowbitsTensorFormatBase::collapse_contiguous_spec( return res; } -namespace megdnn { -namespace detail { -template class LowbitsTensorFormatBase<4>; -} // namespace detail -} // namespace megdnn - /* ===================== Image2DPack4TensorFormat ===================== */ TensorFormat Image2DPack4TensorFormat::make_raw( size_t align_axis, size_t align_size_in_elements, @@ -616,29 +628,28 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const { return make_raw(axis, align_size_in_elements(), vendor()); } -/* ===================== FourBitsAlignedToBytesTensorFormat +/* ===================== LowbitsitsAlignedToBytesTensorFormat * ===================== */ -TensorFormat FourBitsAlignedToBytesTensorFormat::make( - size_t align_size_in_bits) { +TensorFormat LowbitsAlignedToBytesTensorFormat::make(size_t size_nbits) { static std::mutex mtx; static std::unordered_map< - uint32_t, std::unique_ptr> + uint64_t, std::unique_ptr> cache; - megdnn_assert(!(align_size_in_bits % 4)); + megdnn_assert(!(8 % size_nbits)); MEGDNN_LOCK_GUARD(mtx); - auto&& ptr = 
cache[static_cast(align_size_in_bits)]; + auto&& ptr = cache[static_cast(size_nbits)]; if (!ptr) { - ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits}); + ptr.reset(new LowbitsAlignedToBytesTensorFormat{size_nbits}); } return impl_to_tensor_format(ptr.get()); } -TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*, - const void* buf, - size_t size) { +TensorFormat LowbitsAlignedToBytesTensorFormat::deserialize(const Handle*, + const void* buf, + size_t size) { megdnn_assert(size == sizeof(SerializePack)); auto pack = *static_cast(buf); - return make(pack.align_size_in_bits); + return make(pack.size_nbits); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp index 2ae27776..b4be7964 100644 --- a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp +++ b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp @@ -24,6 +24,9 @@ using namespace conv_bias; bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( const SizeArgs& args) const { + if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 && + args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4) + return false; if (args.src_layout->dtype == args.filter_layout->dtype && args.src_layout->dtype == dtype::BFloat16()) { return false; diff --git a/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp b/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp index a995f26f..82fc8c7c 100644 --- a/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp +++ b/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp @@ -103,15 +103,18 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]}, bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]}, dst_{ws_dst, layouts[4]}; - ExecArgs args_{args.opr, + auto conv_op = args.opr->handle()->create_operator(); + conv_op->param() = args.opr->param(); + using Format = param::ConvBias::Format; + conv_op->param().format = Format::NCHW64; + ExecArgs args_{dynamic_cast(conv_op.get()), src_, filter_, bias_, z_, dst_, - ws.get_workspace(3), - args.preprocessed_filter}; - m_underlying_algo.exec(args); + ws.get_workspace(3)}; + m_underlying_algo.exec(args_); // reformat dst nchw642nchw(dst_, {args.dst_tensor->raw_ptr, args.dst_tensor->layout}); } @@ -134,6 +137,9 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout( rst.emplace_back(TensorLayout{}); } rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, dst.dtype}); + for (auto& i : rst) { + i.init_contiguous_stride(); + } return rst; } @@ -145,13 +151,16 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle( auto layouts = make_underlying_tensor_layout( *(args.src_layout), *(args.filter_layout), *(args.bias_layout), *(args.z_layout), *(args.dst_layout)); - SizeArgs args_{args.opr, + auto conv_op = args.opr->handle()->create_operator(); + conv_op->param() = args.opr->param(); + using Format = param::ConvBias::Format; + conv_op->param().format = Format::NCHW64; + SizeArgs args_{dynamic_cast(conv_op.get()), layouts[0], layouts[1], layouts[2], layouts[3], - layouts[4], - args.preprocessed_filter}; + layouts[4]}; size_t ws_size_underlying_algo = m_underlying_algo.get_workspace_in_bytes(args_); if (args.z_layout->ndim > 0) { diff --git a/dnn/src/cuda/conv_bias/helper.cpp b/dnn/src/cuda/conv_bias/helper.cpp index 02955adc..849df376 100644 --- a/dnn/src/cuda/conv_bias/helper.cpp +++ b/dnn/src/cuda/conv_bias/helper.cpp @@ 
-136,6 +136,10 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param, namespace conv_bias { bool is_cudnn_supported(const BiasForwardSizeArgs& args) { + if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 && + args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4) + return false; + if (args.src_layout->dtype == args.filter_layout->dtype && args.src_layout->dtype == dtype::BFloat16()) { return false; diff --git a/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp index fbc851c5..3b50a10b 100644 --- a/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp +++ b/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp @@ -72,11 +72,11 @@ std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key( auto&& param = args.opr->param(); if (args.z_layout->ndim > 0) { kernel_key = - ssprintf("%s_conv_bias_int4_fuse_z_imma_ldg16_%ux%u", + ssprintf("%s_conv_bias_int4_fuse_z_imma8832_ldg16_%ux%u", current_device_arch_name(), m_tile_nhw, m_tile_oc); } else { kernel_key = - ssprintf("%s_conv_bias_int4_imma_ldg16_%ux%u", + ssprintf("%s_conv_bias_int4_imma8832_ldg16_%ux%u", current_device_arch_name(), m_tile_nhw, m_tile_oc); } if (param.nonlineMode == NonlineMode::H_SWISH) { @@ -170,7 +170,7 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec( reorder_imma_filter_bias<4, 64>( reinterpret_cast(filter_ptr), reinterpret_cast(bias_ptr), - args.filter_tensor->compatible_ptr(), + reinterpret_cast(args.filter_tensor->raw_ptr), args.bias_tensor->compatible_ptr(), co, ci, fh, fw, stream); } @@ -292,9 +292,10 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess( param); auto&& stream = cuda_stream(args.opr->handle()); reorder_imma_filter_bias<4, 64>( - args.preprocessed_filter->tensors[0].compatible_ptr(), + reinterpret_cast( + args.preprocessed_filter->tensors[0].raw_ptr), args.preprocessed_filter->tensors[1].compatible_ptr(), - args.filter_tensor->compatible_ptr(), + reinterpret_cast(args.filter_tensor->raw_ptr), args.bias_tensor->compatible_ptr(), co, ci, fh, fw, stream); } diff --git a/dnn/test/common/test_basic_types.cpp b/dnn/test/common/test_basic_types.cpp index d0b2b285..3df876eb 100644 --- a/dnn/test/common/test_basic_types.cpp +++ b/dnn/test/common/test_basic_types.cpp @@ -320,7 +320,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS) { layout = make_layout({16, 32, 7, 7}, {1792, 56, 8, 1}, dtype::QuantizedS4{1.3f}); - layout.format = FourBitsAlignedToBytesTensorFormat::make(8_z); + layout.format = LowbitsAlignedToBytesTensorFormat::make(4_z); EXPECT_TRUE(layout.is_contiguous()); layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}}; @@ -339,12 +339,10 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS_VALID) { DefaultTensorFormat::make()), MegDNNError); ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS32{1.2f}, - FourBitsAlignedToBytesTensorFormat::make(8_z)) - .span(), + LowbitsAlignedToBytesTensorFormat::make(4_z)), MegDNNError); ASSERT_THROW(TensorLayout({16, 32, 7, 7}, dtype::IntB2{}, - FourBitsAlignedToBytesTensorFormat::make(8_z)) - .span(), + LowbitsAlignedToBytesTensorFormat::make(2_z)), MegDNNError); } diff --git a/src/core/impl/graph/operator_node.cpp b/src/core/impl/graph/operator_node.cpp index 55d3d8f4..1197fd73 100644 --- a/src/core/impl/graph/operator_node.cpp +++ b/src/core/impl/graph/operator_node.cpp @@ -338,21 +338,26 @@ void OperatorNodeBase::init_output_format() { TensorFormat 
format, default_; for (auto i : input()) { auto cur = i->format(); - if (cur != default_) { + if (!cur.is_default() && !cur.is_lowbit_aligned()) { if (format == default_) { format = cur; } else { mgb_assert(format == cur, - "multiple non-default formats in inputs: %s vs %s", + "multiple non-default or non-lowbits aligned " + "formats in inputs: %s vs %s", format.to_string().c_str(), cur.to_string().c_str()); } } } for (auto i : output()) { if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { - i->format(default_); + mgb_assert(format.is_default()); + i->format(TensorFormat(i->dtype())); } else { - i->format(format); + if (!format.is_default()) + i->format(format); + else + i->format(TensorFormat(i->dtype())); } } } diff --git a/src/core/impl/graph/var_node_mem_mgr.cpp b/src/core/impl/graph/var_node_mem_mgr.cpp index 7bb130a0..12b9d386 100644 --- a/src/core/impl/graph/var_node_mem_mgr.cpp +++ b/src/core/impl/graph/var_node_mem_mgr.cpp @@ -1063,15 +1063,22 @@ bool VarNodeMemManager::fwd_in2out_readonly( return false; } - mgb_assert( - src != dest && - src->comp_node().mem_node() == dest->comp_node().mem_node() && - dest->m_mem_plan.valid() && src->m_mem_plan.valid() && - dest->m_mem_plan.layout().eq_shape(sub.layout()) && - dest->m_mem_plan.layout().dtype.size() == sub.layout().dtype.size() - ); - assert_in_mem_opt_phase( - SeqMemOptimizer::Status::ALLOW_FWD_IN2OUT_READONLY); + bool cond_low_bit = dest->m_mem_plan.layout().dtype.is_low_bit() && + sub.layout().dtype.is_low_bit() && + dest->m_mem_plan.layout().dtype.low_bit() == + sub.layout().dtype.low_bit(); + bool cond_normal = + !dest->m_mem_plan.layout().dtype.is_low_bit() && + !sub.layout().dtype.is_low_bit() && + dest->m_mem_plan.layout().dtype.size() == sub.layout().dtype.size(); + MGB_MARK_USED_VAR(cond_low_bit); + MGB_MARK_USED_VAR(cond_normal); + mgb_assert(src != dest && + src->comp_node().mem_node() == dest->comp_node().mem_node() && + dest->m_mem_plan.valid() && src->m_mem_plan.valid() && + dest->m_mem_plan.layout().eq_shape(sub.layout()) && + (cond_normal || cond_low_bit)); + assert_in_mem_opt_phase(SeqMemOptimizer::Status::ALLOW_FWD_IN2OUT_READONLY); if (!m_owner_graph->options().seq_opt.enable_mem_plan_opt) return false; diff --git a/src/core/impl/tensor.cpp b/src/core/impl/tensor.cpp index 7dcb2311..a4f03517 100644 --- a/src/core/impl/tensor.cpp +++ b/src/core/impl/tensor.cpp @@ -443,8 +443,8 @@ TensorND::name DEF(resize, &)(const TensorShape& shape) { mgb_assert(m_layout.dtype.valid()); - auto nr_elems = m_layout.init_contiguous_stride(shape); - m_storage.ensure_size(m_layout.dtype.size(nr_elems)); + m_layout = TensorLayout(shape, m_layout.dtype); + m_storage.ensure_size(m_layout.span().dist_byte()); return static_cast(*this); } @@ -584,15 +584,19 @@ TensorND::copy_from(const TensorND &src) { m_layout.dtype.assert_is(src.dtype()); else m_layout.dtype = src.dtype(); - m_layout.format = {}; - size_t size_bytes = dtype().size( - m_layout.init_contiguous_stride(src.shape())); + m_layout = TensorLayout(src.shape(), m_layout.dtype); + size_t size_bytes = m_layout.span().dist_byte(); m_storage.ensure_size(size_bytes); if (!size_bytes) { return static_cast(*this); } - if (src.layout().is_physical_contiguous()) { + // requirement: + // default case, physical contiguous + // lowbit aligned, logical contiguous + if (src.layout().is_physical_contiguous() || + (src.layout().format.is_lowbit_aligned() && + src.layout().is_contiguous())) { if (should_check_overlap(*this, src)) { check_overlapped(m_storage.ptr(), m_storage.ptr() + 
size_bytes, @@ -635,10 +639,17 @@ TensorND::copy_from_fixlayout( src.raw_ptr() + src_span.high_byte); } - bool self_contig = m_layout.is_physical_contiguous(), - src_contig = src.layout().is_physical_contiguous(); + bool self_contig = m_layout.is_physical_contiguous() || + (m_layout.format.is_lowbit_aligned() && + m_layout.is_contiguous()), + src_contig = src.layout().is_physical_contiguous() || + (m_layout.format.is_lowbit_aligned() && + m_layout.is_contiguous()); if (self_contig && src_contig) { - if (m_layout.format.is_default() && src.layout().format.is_default()) { + if ((m_layout.format.is_default() && + src.layout().format.is_default()) || + (m_layout.format.is_lowbit_aligned() && + src.layout().format.is_lowbit_aligned())) { mgb_assert(src_span.low_byte == 0 && dst_span.low_byte == 0 && src_span.high_byte == dst_span.high_byte); m_storage.copy_from(src.storage(), src_span.high_byte); diff --git a/src/core/impl/utils/persistent_cache.cpp b/src/core/impl/utils/persistent_cache.cpp index 4645fbd7..b5db023e 100644 --- a/src/core/impl/utils/persistent_cache.cpp +++ b/src/core/impl/utils/persistent_cache.cpp @@ -261,7 +261,8 @@ PersistentCache::Blob AlgoChooserProfileCache::Key::build_blob() const { ret.push_back(';'); ret.append(ly.dtype.name()); ret.push_back('|'); - mgb_assert(ly.format.is_default(), + mgb_assert(ly.format.is_default() || (ly.format.is_lowbit_aligned() && + ly.dtype.is_low_bit()), "currently only default format is supported"); } if (m_param_size) { diff --git a/src/core/include/megbrain/tensor.h b/src/core/include/megbrain/tensor.h index e22133d1..93bd789f 100644 --- a/src/core/include/megbrain/tensor.h +++ b/src/core/include/megbrain/tensor.h @@ -68,7 +68,10 @@ class SubTensorSpec { //! get offset measured in bytes ptrdiff_t offset_byte() const { - return m_offset_elem * m_layout.dtype.size(); + //! for lowbit cases, offset must aligned to bytes + mgb_assert(!m_layout.dtype.is_low_bit() || + !(m_offset_elem * m_layout.dtype.low_bit() % 8)); + return m_layout.dtype.size(m_offset_elem); } /*! 
diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 5a638efa..a9ba68a4 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -554,14 +554,16 @@ void ParamFusePass::apply(OptState &state) const { SymbolVar new_var; bool is_default_format = var->format().is_default(); - if (cg::is_static_var_value(var) && is_default_format) { + bool is_lowbit_aligned = var->format().is_lowbit_aligned(); + if (cg::is_static_var_value(var) && + (is_default_format || is_lowbit_aligned)) { // use ImmutableTensor for inferable vars HostTensorND hv; hv.copy_from(*inferred_val).sync(); new_var = opr::ImmutableTensor::make( *var->owner_graph(), hv, var_namer.name(var)); } else { - if (is_default_format) { + if (is_default_format || is_lowbit_aligned) { new_var = opr::SharedDeviceTensor::make_const( *var->owner_graph(), inferred_val, var_namer.name(var)); } else { diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index aa6b0c58..c3d29398 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -814,8 +814,13 @@ MGB_IMPL_OPR_GRAD(TypeCvt) { #endif void TypeCvt::mem_plan_fwd_in2out_writable() { - if (input(0)->dtype().size() == output(0)->dtype().size() && - input(0)->layout().is_contiguous()) { + bool cond_low_bit = + input(0)->dtype().is_low_bit() && output(0)->dtype().is_low_bit() && + input(0)->dtype().low_bit() == output(0)->dtype().low_bit(); + bool cond_normal = !input(0)->dtype().is_low_bit() && + !output(0)->dtype().is_low_bit() && + input(0)->dtype().size() == output(0)->dtype().size(); + if ((cond_low_bit || cond_normal) && input(0)->layout().is_contiguous()) { output(0)->set_fwd_in2out_writable(input(0)); } } diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp index 0d9fa34c..316e4969 100644 --- a/src/opr/impl/io.cpp +++ b/src/opr/impl/io.cpp @@ -120,12 +120,11 @@ public: explicit DevValueExecDep(DeviceTensorStorage val) : m_val{std::move(val)} {} }; - void intl::DeviceTensorHolder::init_output_format() { auto format = get_dev_tensor().format(); - mgb_assert(format.is_default(), "non-default tensor format: %s", - format.to_string().c_str()); - // no need to set output foramt since it is initialized as default + mgb_assert(format.is_default() || format.is_lowbit_aligned(), + "invalid tensor format: %s", format.to_string().c_str()); + output(0)->format(format); } void intl::DeviceTensorHolder::init_output_mem_plan(bool dynamic) { diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 57a0d3f6..32900597 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -638,10 +638,18 @@ AlgoChooser::AlgoChooserHelper::profile_single_algo( param.workspace = get_workspace_size_bytes(policy); for (int i = 0; i < arity; ++i) { auto&& src = m_layouts[i]; - mgb_assert(src.format.is_default() && + bool cond_normal = src.format.is_default() && (src.dtype.category() == DTypeCategory::FLOAT || src.dtype.category() == DTypeCategory::INT || - src.dtype.category() == DTypeCategory::QUANTIZED), + src.dtype.category() == DTypeCategory::QUANTIZED); + + bool cond_low_bit = src.dtype.is_low_bit() && + src.format.is_lowbit_aligned() && + (src.dtype.category() == DTypeCategory::QUANTIZED || + src.dtype.category() == DTypeCategory::LOWBIT); + MGB_MARK_USED_VAR(cond_normal); + MGB_MARK_USED_VAR(cond_low_bit); + mgb_assert(cond_normal || cond_low_bit, "unsupported layout in profiling: %s", src.to_string().c_str()); param.dtypes[i] = 
src.dtype.enumv(); diff --git a/src/opr/impl/search_policy/profiler.cpp b/src/opr/impl/search_policy/profiler.cpp index b462f7b1..705c02b2 100644 --- a/src/opr/impl/search_policy/profiler.cpp +++ b/src/opr/impl/search_policy/profiler.cpp @@ -175,15 +175,17 @@ typename TimedProfiler::TResult TimedProfiler::prof_impl( case DTypeTrait<_dt>::enumv: \ return _dt(1.0f, static_cast(0)) cb(dtype::Quantized8Asymm); + cb(dtype::Quantized4Asymm); #undef cb #define cb(_dt) \ case DTypeTrait<_dt>::enumv: \ return _dt(1.0f) - + cb(dtype::QuantizedS8); cb(dtype::QuantizedS16); cb(dtype::QuantizedS32); + cb(dtype::QuantizedS4); default: return DType::from_enum(enumv); #undef cb diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp index 021ddf34..880ea6a5 100644 --- a/src/opr/test/dnn/convolution.cpp +++ b/src/opr/test/dnn/convolution.cpp @@ -2603,4 +2603,306 @@ TEST_F(TestNoWeightPreprocess, NoPreprocess) { #endif +namespace { +// FIXME change comp node from "cpu0" to "gpu0" +TEST(TestOprDNN, ConvBiasInt4NCHW) { + auto run = [](size_t N, size_t C, size_t H, size_t W, size_t F, size_t S, + size_t P) { + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + + HostTensorGenerator gen; + auto mkvar = [&gen](const char* name, const TensorShape& shp, + const DType& dtype, + std::shared_ptr graph, + const CompNode& cn) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)) + .rename(name), + dtype); + }; + auto mkcvar = [&gen](const char* name, const TensorShape& shp, + const DType& dtype, + std::shared_ptr graph, + const CompNode& cn) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + + using Policy = opr::ConvBias::ExecutionPolicy; + using Strategy = Policy::Strategy; + auto x = mkvar("x", {N, C * 4, H, W}, dtype::QuantizedS4(1.19960327f), + graph, cn), + w = mkcvar("w1", {C, C * 4, F, F}, dtype::QuantizedS4(1.19970327f), + graph, cn), + b = mkcvar("b1", {1, C, 1, 1}, + dtype::QuantizedS32(1.19960327f * 1.19970327f), graph, + cn); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; + param.stride_h = param.stride_w = S; + param.pad_h = param.pad_w = P; + Policy policy; + policy.strategy = Strategy::PROFILE; + + auto y = opr::ConvBias::make( + x, w, b, param, policy, + OperatorNodeConfig{dtype::QuantizedS4(11.9960501f)}); + y = opr::TypeCvt::make(y, dtype::Float32()); + auto x_f32 = opr::TypeCvt::make(x, dtype::Float32()), + w_f32 = opr::TypeCvt::make(w, dtype::Float32()), + b_f32 = opr::TypeCvt::make(b, dtype::Float32()); + auto y_f32 = opr::ConvBias::make(x_f32, w_f32, b_f32, param, policy); + auto y_q4 = opr::TypeCvt::make(y_f32, dtype::QuantizedS4{11.9960501f}); + y_q4 = opr::TypeCvt::make(y_q4, dtype::Float32()); + HostTensorND host_y, host_y_q4; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_q4, host_y_q4)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_q4, 1e-3); + }; + run(2, 64, 14, 14, 3, 2, 1); + run(2, 64, 7, 7, 3, 1, 1); + run(2, 64, 14, 14, 1, 2, 0); + run(2, 64, 7, 7, 1, 1, 0); +} + +TEST(TestOprDNN, ConvBiasInt4NCHW64) { + auto nchw2nchw64 = [](SymbolVar x) { + auto y = opr::RelayoutFormat::make( + x, opr::RelayoutFormat::Param::Mode::NCHW_NCHW64); + return y; + }; + + auto nchw642nchw = [](SymbolVar x) { + auto y = opr::RelayoutFormat::make( + x, opr::RelayoutFormat::Param::Mode::NCHW64_NCHW); + return y; 
+ }; + + auto run = [&](size_t N, size_t C, size_t H, size_t W, size_t F, size_t S, + size_t P) { + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + + HostTensorGenerator gen; + auto mkvar = [&gen](const char* name, const TensorShape& shp, + const DType& dtype, + std::shared_ptr graph, + const CompNode& cn) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)) + .rename(name), + dtype); + }; + auto mkcvar = [&gen](const char* name, const TensorShape& shp, + const DType& dtype, + std::shared_ptr graph, + const CompNode& cn) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + + using Policy = opr::ConvBias::ExecutionPolicy; + using Strategy = Policy::Strategy; + auto x = mkvar("x", {N, C / 16, H, W, 64}, + dtype::QuantizedS4(1.19960327f), graph, cn), + w = mkcvar("w1", {C, C / 16, F, F, 64}, + dtype::QuantizedS4(1.19970327f), graph, cn), + b = mkcvar("b1", {1, C / 64, 1, 1, 64}, + dtype::QuantizedS32(1.19960327f * 1.19970327f), graph, + cn); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW64; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; + param.stride_h = param.stride_w = S; + param.pad_h = param.pad_w = P; + Policy policy; + policy.strategy = Strategy::PROFILE; + + auto y = opr::ConvBias::make( + x, w, b, param, policy, + OperatorNodeConfig{dtype::QuantizedS4(11.9960501f)}); + y = opr::TypeCvt::make(y, dtype::Float32()); + x = nchw642nchw(x); + w = nchw642nchw(w); + b = nchw642nchw(b); + auto x_f32 = opr::TypeCvt::make(x, dtype::Float32()), + w_f32 = opr::TypeCvt::make(w, dtype::Float32()), + b_f32 = opr::TypeCvt::make(b, dtype::Float32()); + param.format = opr::ConvBias::Param::Format::NCHW; + auto y_f32 = opr::ConvBias::make(x_f32, w_f32, b_f32, param, policy); + auto y_q4 = opr::TypeCvt::make(y_f32, dtype::QuantizedS4{11.9960501f}); + y_q4 = opr::TypeCvt::make(y_q4, dtype::Float32()); + y_q4 = nchw2nchw64(y_q4); + HostTensorND host_y, host_y_q4; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_q4, host_y_q4)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_q4, 1e-3); + }; + run(2, 64, 14, 14, 3, 2, 1); + run(2, 64, 7, 7, 3, 1, 1); + run(2, 64, 14, 14, 1, 2, 0); + run(2, 64, 7, 7, 1, 1, 0); +} + +TEST(TestOprDNN, ConvBiasInt4Serialize) { + using namespace serialization; + + float inp_scale = 1.20210327f; + float filt_scale = 1.20210406f; + float bias_scale = inp_scale * filt_scale; + DType output_dtype = dtype::QuantizedS4{inp_scale}; + + HostTensorGenerator gen; + std::shared_ptr xv; + auto mkvar = [&gen](const char* name, const DType& dtype, + std::shared_ptr graph, + std::shared_ptr val) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, val).rename(name), dtype); + }; + auto mkcvar = + [&gen](const char* name, const TensorShape& shp, const DType& dtype, + std::shared_ptr graph, const CompNode& cn) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + + auto fname = output_file("ConvBiasInt4Serialize"); + HostTensorND y1, y2; + auto dump = [&]() { + opr::ConvBias::Param param; + param.mode = Mode::CONVOLUTION; + + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + xv = gen({1, 64, 56, 56}, cn); + auto x = mkvar("x", dtype::QuantizedS4{inp_scale}, graph, xv); + auto w = mkcvar("w", {256, 64, 1, 1}, dtype::QuantizedS4{filt_scale}, graph, cn); + auto b = 
mkcvar("b", {1, 256, 1, 1}, dtype::QuantizedS32{bias_scale}, graph, cn); + auto y = opr::ConvBiasForward::make(x, w, b, param, {}, + OperatorNodeConfig{output_dtype}); + auto w1 = mkcvar("w1", {64, 256, 1, 1}, dtype::QuantizedS4{filt_scale}, + graph, cn); + auto b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32{bias_scale}, + graph, cn); + y = opr::ConvBiasForward::make(y, w1, b1, param, {}, + OperatorNodeConfig{output_dtype}); + y = opr::TypeCvt::make(y, dtype::Float32()); + auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str())); + auto func = graph->compile({make_callback_copy(y, y1)}); + func->execute(); + func->wait(); + auto rst = dumper->dump({y}); + ASSERT_EQ(rst.outputs.size(), 1u); + }; + + auto load = [&]() { + auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str())); + auto rst = loader->load(); + for (const auto& t : rst.tensor_map) { + t.second->copy_from(*xv).sync(); + } + auto func = rst.graph->compile( + {make_callback_copy(rst.output_var_list[0], y2)}); + func->execute(); + func->wait(); + ASSERT_EQ(rst.output_var_list.size(), 1u); + EXPECT_EQ(rst.output_var_list[0].dtype(), dtype::Float32()); + }; + + dump(); + load(); + MGB_ASSERT_TENSOR_NEAR(y1, y2, 1e-3); +} + +TEST(TestOprDNN, ConvBiasInt4SerializeWithParamFuse) { + using namespace serialization; + + float inp_scale = 1.20210327f; + float filt_scale = 1.20210406f; + float bias_scale = inp_scale * filt_scale; + DType output_dtype = dtype::QuantizedS4{inp_scale}; + + HostTensorGenerator gen; + std::shared_ptr xv; + auto mkvar = [&gen](const char* name, const DType& dtype, + std::shared_ptr graph, + std::shared_ptr val) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, val).rename(name), dtype); + }; + auto mkcvar = + [&gen](const char* name, const TensorShape& shp, const DType& dtype, + std::shared_ptr graph, const CompNode& cn) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + + auto fname = output_file("ConvBiasInt4SerializeWithParamFuse"); + HostTensorND y1, y2; + auto dump = [&]() { + opr::ConvBias::Param param; + param.mode = Mode::CONVOLUTION; + + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + xv = gen({1, 64, 56, 56}, cn); + auto x = mkvar("x", dtype::QuantizedS4{inp_scale}, graph, xv); + auto w = mkcvar("w", {256, 64, 1, 1}, dtype::QuantizedS4{filt_scale}, graph, cn); + auto b = mkcvar("b", {1, 256, 1, 1}, dtype::QuantizedS32{bias_scale}, graph, cn); + auto y = opr::ConvBiasForward::make(x, w, b, param, {}, + OperatorNodeConfig{output_dtype}); + auto w1 = mkcvar("w1", {64, 256, 1, 1}, dtype::QuantizedS4{filt_scale}, + graph, cn); + auto b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32{bias_scale}, + graph, cn); + y = opr::ConvBiasForward::make(y, w1, b1, param, {}, + OperatorNodeConfig{output_dtype}); + y = opr::TypeCvt::make(y, dtype::Float32()); + SymbolVar y_param_fused; + unpack_vector(gopt::GraphOptimizer{} + .add_pass() + .apply({{y}}) + .endpoint_vars(), + y_param_fused); + auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str())); + auto func = graph->compile({make_callback_copy(y_param_fused, y1)}); + func->execute(); + func->wait(); + auto rst = dumper->dump({y_param_fused}); + ASSERT_EQ(rst.outputs.size(), 1u); + }; + + auto load = [&]() { + auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str())); + auto rst = loader->load(); + for (const auto& t : rst.tensor_map) { + t.second->copy_from(*xv).sync(); + } + auto func = rst.graph->compile( 
+                {make_callback_copy(rst.output_var_list[0], y2)});
+        func->execute();
+        func->wait();
+        ASSERT_EQ(rst.output_var_list.size(), 1u);
+        EXPECT_EQ(rst.output_var_list[0].dtype(), dtype::Float32());
+    };
+
+    dump();
+    load();
+    MGB_ASSERT_TENSOR_NEAR(y1, y2, 1e-3);
+}
+}  // namespace
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
-- 
GitLab
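
Usage note (not part of the patch): a minimal sketch of how the refactored format is picked up. With the new Format(DType) constructor, any 1/2/4-bit dtype maps to LowbitsAlignedToBytesTensorFormat instead of the removed 4-bit-only FourBitsAlignedToBytesTensorFormat. The include paths and the exact asserts below follow the headers and test code touched in this patch, but the snippet is illustrative only and has not been compiled against a specific MegDNN revision.

#include "megdnn/basic_types.h"
#include "megdnn/tensor_format.h"

using namespace megdnn;

void lowbit_layout_example() {
    // A low-bit dtype now selects LOWBITS_ALIGNED_TO_BYTE automatically:
    // QuantizedS4 is 4 bits, so two elements are packed per byte and the
    // strides of non-innermost dimensions are aligned to whole bytes.
    TensorLayout layout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
    megdnn_assert(layout.format.is_lowbit_aligned());
    megdnn_assert(layout.is_contiguous());

    // The format can also be requested explicitly by element size in bits
    // (1, 2 or 4), replacing the old make(align_size_in_bits) interface.
    TensorFormat fmt = LowbitsAlignedToBytesTensorFormat::make(4);
    TensorLayout explicit_layout{{16, 32, 7, 7}, dtype::QuantizedS4{1.3f}, fmt};
    megdnn_assert(explicit_layout.format.is_lowbit_aligned());
}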