diff --git a/dnn/src/common/named_tensor.cpp b/dnn/src/common/named_tensor.cpp
index 919145c55425bd1fbdc64ff8d2b3b957fc83674a..9e952475d36364b9c56aeac8b284fbc399c987c1 100644
--- a/dnn/src/common/named_tensor.cpp
+++ b/dnn/src/common/named_tensor.cpp
@@ -246,6 +246,8 @@ NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) {
             return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}};
         case Format::NCHW44_DOT:
             return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}};
+        case Format::NHWCD4:
+            return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}};
         default:
             megdnn_throw(ssprintf("Format unimplement(%d)", static_cast<int>(format))
                                  .c_str());
diff --git a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
index 6ea806bcb42c58989ec823a99469c277e21f7a57..b0e93da8a55b2a3258fb0e5d1e9dbc28f23daf89 100644
--- a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
+++ b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
@@ -229,6 +229,30 @@ struct OprSingleInOutTensorFormatsDispatcherImpl {
     }
 };
 
+template <>
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWCD4> {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
+        OprTensorFormatsConfiguration config;
+        config.typeinfo = opr->dyn_typeinfo();
+        config.opr_format = OprFormat::NHWCD4;
+        config.config_id = OprFormatConfigID::NHWCD4;
+        bool available =
+                opr->input(0)->dtype().enumv() == DTypeEnum::Float32 ||
+                DNN_FLOAT16_SELECT(
+                        (opr->input(0)->dtype().enumv() == DTypeEnum::Float16), true) ||
+                opr->input(0)->dtype().enumv() == DTypeEnum::Int8 ||
+                opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
+        config.input_dtypes = {opr->input(0)->dtype().enumv()};
+        config.input_tensor_types = {TensorType::FEATURE};
+        config.output_dtypes = {opr->output(0)->dtype().enumv()};
+        config.input_tensor_formats = {TensorFormats::NHCWc4};
+        config.output_tensor_formats = {TensorFormats::NHCWc4};
+        if (available)
+            return config;
+        return None;
+    }
+};
+
 template <typename Opr, OprFormatConfigID config_id>
 struct ConvTensorFormatsDispatcherImpl;
 
@@ -814,6 +838,55 @@ struct ConvTensorFormatsDispatcherImpl
     }
 };
 
+template <typename Opr>
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWCD4> {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
+        const auto& conv = opr->cast_final_safe<Opr>();
+        OprTensorFormatsConfiguration config;
+        config.typeinfo = opr->dyn_typeinfo();
+        config.opr_format = OprFormat::NHWCD4;
+        config.config_id = OprFormatConfigID::NHWCD4;
+        for (size_t i = 0; i < opr->input().size(); ++i) {
+            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
+            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
+            config.input_tensor_types.emplace_back(tensor_type);
+        }
+        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
+        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
+            if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+                opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
+                config.input_tensor_formats = {
+                        TensorFormats::NHCWc4, TensorFormats::KRSCk4c4,
+                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+            } else {
+                config.input_tensor_formats = {
+                        TensorFormats::NHCWc4, TensorFormats::KRSCk4,
+                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+            }
+        } else {
+            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
+            if (is_channel_wise_conv(opr)) {
+                config.input_tensor_formats = {
+                        TensorFormats::NHCWc4, TensorFormats::C1RSc4,
+                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+            } else {
+                if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+                    opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
+                    config.input_tensor_formats = {
+                            TensorFormats::NHCWc4, TensorFormats::GKRSCk4c4,
+                            TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+                } else {
+                    config.input_tensor_formats = {
+                            TensorFormats::NHCWc4, TensorFormats::GKRSCk4,
+                            TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+                }
+            }
+        }
+        config.output_tensor_formats = {TensorFormats::NHCWc4};
+        return config;
+    }
+};
+
 template <>
 struct ConvTensorFormatsDispatcherImpl<
         opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> {
@@ -919,6 +992,57 @@ struct ConvTensorFormatsDispatcherImpl<
     }
 };
 
+template <>
+struct ConvTensorFormatsDispatcherImpl<
+        opr::ConvolutionBackwardData, OprFormatConfigID::NHWCD4> {
+    using Opr = opr::ConvolutionBackwardData;
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
+        const auto& conv = opr->cast_final_safe<Opr>();
+        OprTensorFormatsConfiguration config;
+        config.typeinfo = opr->dyn_typeinfo();
+        config.opr_format = OprFormat::NHWCD4;
+        config.config_id = OprFormatConfigID::NHWCD4;
+        for (size_t i = 0; i < opr->input().size(); ++i) {
+            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
+            TensorType tensor_type = i == 0 ? TensorType::WEIGHT : TensorType::FEATURE;
+            config.input_tensor_types.emplace_back(tensor_type);
+        }
+        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
+        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
+            if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+                opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
+                config.input_tensor_formats = {
+                        TensorFormats::KRSCk4c4, TensorFormats::NHCWc4,
+                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+            } else {
+                config.input_tensor_formats = {
+                        TensorFormats::KRSCk4, TensorFormats::NHCWc4,
+                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+            }
+        } else {
+            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
+            if (is_channel_wise_conv(opr)) {
+                config.input_tensor_formats = {
+                        TensorFormats::C1RSc4, TensorFormats::NHCWc4,
+                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+            } else {
+                if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+                    opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
+                    config.input_tensor_formats = {
+                            TensorFormats::GKRSCk4c4, TensorFormats::NHCWc4,
+                            TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+                } else {
+                    config.input_tensor_formats = {
+                            TensorFormats::GKRSCk4, TensorFormats::NHCWc4,
+                            TensorFormats::NHCWc4, TensorFormats::NHCWc4};
+                }
+            }
+        }
+        config.output_tensor_formats = {TensorFormats::NHCWc4};
+        return config;
+    }
+};
+
 struct StaticData {
     struct KeyHash {
         size_t operator()(const std::pair<Typeinfo*, OprFormatConfigID>& val) const {
@@ -969,6 +1093,7 @@ StaticData::StaticData() {
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID);
+    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NHWCD4);
 
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWC);
@@ -979,15 +1104,18 @@ StaticData::StaticData() {
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID);
+    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWCD4);
 
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW4);
+    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWCD4);
 
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWC);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW4);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW64);
+    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWCD4);
 
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWC);
@@ -997,10 +1125,12 @@ StaticData::StaticData() {
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
+    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWCD4);
 
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
+    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NHWCD4);
 
 #undef OPR_TENSOR_FORMATS_CONFIG_REG
 #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
diff --git a/src/gopt/impl/global_layout_transform/profiler_impl.cpp b/src/gopt/impl/global_layout_transform/profiler_impl.cpp
index b1ef404b118f9dfeedf382df3fc7426cdc8b7755..48d7188a31b4fa1b20b2cf68e8fee6e90a8458a2 100644
--- a/src/gopt/impl/global_layout_transform/profiler_impl.cpp
+++ b/src/gopt/impl/global_layout_transform/profiler_impl.cpp
@@ -22,6 +22,7 @@
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/plugin/base.h"
 #include "megbrain/serialization/sereg.h"
+#include "megdnn/tensor_format.h"
 
 using namespace mgb;
 using namespace cg;
@@ -281,9 +282,6 @@ float ProfilerImpl::profile_operator(
             std::min(config.input_tensor_formats.size(), opr->input().size());
     for (; i < nr_input_tensor; ++i) {
         auto&& var = opr->input(i);
-        auto&& cn = var->comp_node();
-        auto&& dtype = var->dtype();
-        auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
         TensorShape aligned_shape;
         if (config.input_tensor_types[i] == TensorType::WEIGHT) {
             mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT);
@@ -299,9 +297,12 @@ float ProfilerImpl::profile_operator(
                     var, base_config.input_tensor_formats[i],
                     config.input_tensor_formats[i], extra_attribute);
         }
-        dval->resize(aligned_shape);
+        std::shared_ptr<DeviceTensorND> dval = create_device_tensor_helper(
+                config, i, var, aligned_shape, extra_attribute);
+
         if (config.input_tensor_types[i] == TensorType::WEIGHT) {
-            new_inps[i] = opr::SharedDeviceTensor::make_const(*graph, dval).node();
+            new_inps[i] =
+                    opr::SharedDeviceTensorWithFormat::make_const(*graph, dval).node();
         } else {
             new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node();
         }
@@ -368,10 +369,27 @@ float ProfilerImpl::profile_var_node(
         const VarNode* var, TensorFormats base_format, const ReformatKey& key) const {
     auto&& cn = var->comp_node();
     auto&& dtype = var->dtype();
-    auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
     auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
             var, base_format, key.input_format, key.attribute);
-    dval->resize(aligned_tensor_shape);
+
+    std::shared_ptr<DeviceTensorND> dval;
+    if (key.input_format == TensorFormats::NHCWc4 &&
+        key.attribute & ReformatAttribute::IMAGE2D) {
+        size_t align_axis = 2;
+        auto named_tensor = tensor_formats_to_named_tensor_shape(key.input_format);
+        for (size_t n = 0; n < named_tensor.ndim; n++) {
+            if (named_tensor[n].name() == megdnn::Dimension::Name::C) {
+                align_axis = n;
+                break;
+            }
+        }
+        dval = std::make_shared<DeviceTensorND>(
+                cn, aligned_tensor_shape, dtype,
+                megdnn::Image2DPack4TensorFormat::make(
+                        align_axis, opr::intl::get_megdnn_handle(cn)));
+    } else
+        dval = std::make_shared<DeviceTensorND>(cn, aligned_tensor_shape, dtype);
+
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
     graph->options().var_sanity_check_first_run = false;
@@ -516,6 +534,8 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
             return OprFormatConfigID::NHWC;
         case TensorFormats::CHWNc4:
             return OprFormatConfigID::CHWN4;
+        case TensorFormats::NHCWc4:
+            return OprFormatConfigID::NHWCD4;
         default:
             mgb_throw(
                     MegBrainError, "tensor format(%u) is not supported",
@@ -523,6 +543,39 @@
     }
 }
 
+std::shared_ptr<DeviceTensorND> ProfilerImpl::create_device_tensor_helper(
+        const OprTensorFormatsConfiguration& config, const size_t inp_idx,
+        const VarNode* var, const TensorShape aligned_shape,
+        ReformatAttribute extra_attribute) const {
+    auto&& cn = var->comp_node();
+    auto&& dtype = var->dtype();
+    std::shared_ptr<DeviceTensorND> dval;
+    if (config.config_id == OprFormatConfigID::NHWCD4 &&
+        extra_attribute & ReformatAttribute::IMAGE2D) {
+        size_t align_axis = 2;
+        auto named_tensor = tensor_formats_to_named_tensor_shape(
+                config.input_tensor_formats[inp_idx]);
+        for (size_t n = 0; n < named_tensor.ndim; n++) {
+            if (named_tensor[n].name() == megdnn::Dimension::Name::C) {
+                align_axis = n;
+                break;
+            }
+        }
+        // channel wise weight
+        bool is_channel_wise =
+                config.input_tensor_formats[inp_idx] == TensorFormats::C1RSc4;
+        if (is_channel_wise)
+            align_axis = 1;
+        dval = std::make_shared<DeviceTensorND>(
+                cn, aligned_shape, dtype,
+                megdnn::Image2DPack4TensorFormat::make(
+                        align_axis, opr::intl::get_megdnn_handle(cn)));
+    } else {
+        dval = std::make_shared<DeviceTensorND>(cn, aligned_shape, dtype);
+    }
+    return dval;
+}
+
 /* ================== ProfilerBase =================*/
 std::string ProfilerBase::OperatorNodeRecord::to_string() const {
     auto str = ssprintf(
diff --git a/src/gopt/impl/global_layout_transform/reformat_manager.cpp b/src/gopt/impl/global_layout_transform/reformat_manager.cpp
index 38c301c31e267290a86864f43c82f406ff3a5f39..8d95a12ea91885f08e392afde282aa6d786a5322 100644
--- a/src/gopt/impl/global_layout_transform/reformat_manager.cpp
+++ b/src/gopt/impl/global_layout_transform/reformat_manager.cpp
@@ -249,7 +249,7 @@ ReformatManager::ReformatManager() {
         m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
     }
     {
-        auto i = TensorFormats::KCRS, o = TensorFormats::GKRSCk4;
+        auto i = TensorFormats::GKCRS, o = TensorFormats::GKRSCk4;
         auto&& impl = [](const VarNodeArray& vars) {
             return opr::RelayoutFormat::make(
                            vars[0],
@@ -259,7 +259,7 @@ ReformatManager::ReformatManager() {
         m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
     }
     {
-        auto i = TensorFormats::KCRS, o = TensorFormats::C1RSc4;
+        auto i = TensorFormats::C11RS, o = TensorFormats::C1RSc4;
         auto&& impl = [](const VarNodeArray& vars) {
             return opr::RelayoutFormat::make(
                            vars[0],
@@ -268,6 +268,21 @@ ReformatManager::ReformatManager() {
         };
         m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
     }
+    {
+        auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4;
+        auto&& impl1 = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o}, impl1);
+        auto&& impl2 = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4_NCHW)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{o, i}, impl2);
+    }
     {
         auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4;
         auto&& impl = [](const VarNodeArray& vars) {
             return opr::RelayoutFormat::make(
@@ -281,7 +296,7 @@ ReformatManager::ReformatManager() {
         auto i = TensorFormats::NHCWc4, o = TensorFormats::NCHW;
         auto&& impl = [](const VarNodeArray& vars) {
             return opr::RelayoutFormat::make(
-                           vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I)
+                           vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW)
                    .node();
         };
         m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
@@ -346,6 +361,15 @@ ReformatManager::ReformatImpl ReformatManager::get(const ReformatKey& key) const
             return rst;
         }
     }
+    if (key.attribute == Attribute::IMAGE2D) {
+        auto key_ = key;
+        key_.input_dtype = DTypeEnum::Float32;
+        key_.output_dtype = DTypeEnum::Float32;
+        auto find = m_cache.find(key_);
+        if (find != m_cache.end()) {
+            return find->second;
+        }
+    }
     mgb_assert(
             !(key.attribute & Attribute::IMAGE2D) &&
             !(key.attribute & Attribute::IC_SMALL));
@@ -682,7 +706,8 @@ TensorShape ReformatManager::make_aligned_weight_shape(
     auto target_shape = tensor_formats_to_named_tensor_shape(target_formats);
     for (size_t i = 0; i < target_shape.ndim; ++i) {
         auto name = target_shape[i].name();
-        if ((name == Dimension::Name::K || name == Dimension::Name::N) &&
+        if ((name == Dimension::Name::K || name == Dimension::Name::N ||
+             (extra_formats == TensorFormats::NHCWc4 && name == Dimension::Name::C)) &&
             target_shape[i].extent() == UNDETERMINED_EXTENT) {
             size_t out_channels = tshp[i] * target_shape[i].stride();
             tshp[i] = divup(out_channels, out_channel_alignment) *
diff --git a/src/gopt/impl/global_layout_transform/utils.h b/src/gopt/impl/global_layout_transform/utils.h
index 64f25b101170ce970c5ca4ddb9d52f8f2aef1771..731caa0d6ece83e454189a1b225ed24a93d721ef 100644
--- a/src/gopt/impl/global_layout_transform/utils.h
+++ b/src/gopt/impl/global_layout_transform/utils.h
@@ -32,6 +32,7 @@ static inline const char* opr_format_to_string(
         cb(NCHW44);
         cb(NCHW88);
         cb(NCHW44_DOT);
+        cb(NHWCD4);
         default:
             mgb_assert(
                     false, "Invalid opr format(got:%u)",
@@ -63,6 +64,7 @@ static inline const char* config_id_to_string(
         cb(NCHW88_HYBRID);
         cb(NCHW44_DOT);
         cb(NCHW44_DOT_HYBRID);
+        cb(NHWCD4);
         default:
             mgb_assert(
                     false, "Invalid config id(got:%u)",
@@ -95,6 +97,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
             return TensorFormats::NCHWc8;
         case OprFormat::NCHW44_DOT:
             return TensorFormats::NCHWc4;
+        case OprFormat::NHWCD4:
+            return TensorFormats::NHCWc4;
         default:
             mgb_throw(
                     AssertionError, "format(%s) is not supported",
diff --git a/src/gopt/include/megbrain/gopt/profiler.h b/src/gopt/include/megbrain/gopt/profiler.h
index 0a05a011bedc15d0a3744117f72e6f1f440e35c3..93b8be72b93d47e8c717095cf8b9604f12a53857 100644
--- a/src/gopt/include/megbrain/gopt/profiler.h
+++ b/src/gopt/include/megbrain/gopt/profiler.h
@@ -202,6 +202,11 @@ protected:
             const ReformatKey& key) const;
     OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const;
 
+    std::shared_ptr<DeviceTensorND> create_device_tensor_helper(
+            const OprTensorFormatsConfiguration& config, const size_t inp_idx,
+            const VarNode* var, const TensorShape aligned_shape,
+            ReformatAttribute extra_attribute) const;
+
     OprFootprint m_opr_footprint;
     float m_opr_threshold;  /// a threshold, when the computation of the newly
                             /// created operator that is built in some opr
diff --git a/src/gopt/test/cache_data.h b/src/gopt/test/cache_data.h
index 95501fba61488ef3ba4104c27d7de90859798736..6c2fb3e112447e0b171b7d4a804d0d4c711be914 100644
Binary files a/src/gopt/test/cache_data.h and b/src/gopt/test/cache_data.h differ
diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp
index 74a29205bb9df6e15bd1fd566edc98b7237475d8..e18bfd9448d04d430e9949475fe74e20840aad02 100644
--- a/src/opr/impl/io.cpp
+++ b/src/opr/impl/io.cpp
@@ -336,6 +336,10 @@ cg::OperatorNodeBase::NodeProp* VolatileSharedDeviceTensor::do_make_node_prop()
     return ret;
 }
 
+void VolatileSharedDeviceTensor::init_output_format() {
+    output(0)->format(get_dev_tensor().format());
+}
+
 SymbolVar VolatileSharedDeviceTensor::make(
         ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
         const OperatorNodeConfig& config) {
diff --git a/src/opr/include/megbrain/opr/io.h b/src/opr/include/megbrain/opr/io.h
index 4e4118aaa8fd56cd61451014693553b6cabb330e..a1aba48df6712e4ff0cc2814351e7abc5b7fcaeb 100644
--- a/src/opr/include/megbrain/opr/io.h
+++ b/src/opr/include/megbrain/opr/io.h
@@ -337,6 +337,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
 public:
     using Super::Super;
 
+    void init_output_format() override;
+
     MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
             const OperatorNodeConfig& config = {});
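
Illustrative note (not part of the patch): the NHWCD4 layout registered above stores an NCHW tensor as (N, H, C/4, W, 4), which matches the named shape {N, H, C//4, W, C%4} added in named_tensor.cpp; inside the framework the conversion is performed by RelayoutFormat with Mode::NCHW_NHWCD4 / NHWCD4_NCHW, and the channel group of four is what Image2DPack4TensorFormat aligns on in profiler_impl.cpp. The standalone C++ sketch below only demonstrates the index mapping; the helper name nchw_to_nhwcd4 is hypothetical and no MegEngine API is used.

#include <cstddef>
#include <vector>

// Hypothetical standalone helper, not a MegEngine API: packs an NCHW float tensor of
// shape (N, C, H, W) into NHWCD4 layout (N, H, ceil(C/4), W, 4). Channel positions
// beyond C in the last group are left zero-padded.
std::vector<float> nchw_to_nhwcd4(
        const std::vector<float>& src, size_t N, size_t C, size_t H, size_t W) {
    const size_t C4 = (C + 3) / 4;  // number of 4-channel groups
    std::vector<float> dst(N * H * C4 * W * 4, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w) {
                    // NCHW source index and (N, H, C/4, W, 4) destination index
                    size_t src_idx = ((n * C + c) * H + h) * W + w;
                    size_t dst_idx = (((n * H + h) * C4 + c / 4) * W + w) * 4 + c % 4;
                    dst[dst_idx] = src[src_idx];
                }
    return dst;
}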