diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index bd23381234657bcab6b45b4a6714ebe572117829..eda63b4a6e710a221acba71bfb2002b2e4b77252 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -963,11 +963,21 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight( bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess( const cg::OperatorNodeBase& opr) const { - bool param_merged = opr.input(1) - ->owner_opr() - ->same_type(); - return opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) && - (cg::is_const_var_value(opr.input(1)) || param_merged); + if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE)) + return false; + if (cg::is_const_var_value(opr.input(1))) + return true; + auto* input_opr = opr.input(1)->owner_opr(); + if (input_opr->same_type() || + input_opr->same_type()) + return true; + auto* sdt = input_opr->try_cast_final(); + if (sdt && sdt->const_value()) + return true; + auto* sdtf = input_opr->try_cast_final(); + if (sdtf && sdtf->const_value()) + return true; + return false; } /* ==================== ConvolutionForward ==================== */ diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp index 279f26e2da61be766d3deb6eea2e6d1ea6b7d358..bff006e1057dbec873352e837107902898da9e95 100644 --- a/src/opr/impl/io.cpp +++ b/src/opr/impl/io.cpp @@ -307,20 +307,6 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() { comp_node(m_dev_data->comp_node()); } -bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) { - if (m_const_value) { - if (dest) { - if (m_static_infer.empty()) { - m_static_infer.comp_node(CompNode::default_cpu()) - .copy_from(*m_dev_data); - } - *dest = m_static_infer; - } - return true; - } - return false; -} - cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const { return cg::static_infer::SourceType::CONSTANT; } @@ -886,24 +872,6 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() { }; mgr.register_shape_infer(output(i), {SourceType::CONSTANT, {}, infer_shp}); - - auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) { - if (m_host_values.empty()) { - m_host_values.resize(m_values.size()); - } - if (m_host_values[i].empty()) { - m_host_values[i] - .comp_node(CompNode::default_cpu()) - .copy_from(*m_values[i]); - } - if (!m_host_values[i].empty()) { - dest = m_host_values[i]; - return true; - } - return false; - }; - mgr.register_value_infer(output(i), - {SourceType::CONSTANT, {}, infer_val}); } } diff --git a/src/opr/include/megbrain/opr/io.h b/src/opr/include/megbrain/opr/io.h index 3be2c25155c5199a47f0bfaa70bd528bbd9205e0..caad0f7c8bdb616a463f34fd4c9da5fae100dccb 100644 --- a/src/opr/include/megbrain/opr/io.h +++ b/src/opr/include/megbrain/opr/io.h @@ -75,12 +75,14 @@ class DeviceTensorHolder: public HostIONodeBase { */ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { std::shared_ptr m_dev_data; - DeviceTensorND m_static_infer; bool m_const_value; const TensorShape& get_output_shape() override; - bool fill_in_static_infer(DeviceTensorND* dest) override; + bool fill_in_static_infer(DeviceTensorND* dest) override { + MGB_MARK_USED_VAR(dest); + return false; + } void init_output_comp_node() override; @@ -131,8 +133,6 @@ private: void init_output_comp_node() override; void init_output_static_infer_desc() override; NodeProp* do_make_node_prop() const override; - - SmallVector m_host_values; }; } // namespace intl