/**
 * \file dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

namespace {
// Deduce the layout produced by a RelayoutFormat op in the given mode.
// An empty (ndim == 0) source layout is passed through unchanged.
inline void deduce_reformat_layout(std::unique_ptr<RelayoutFormat>& relayout,
                                   const TensorLayout& src_layout,
                                   TensorLayout& dst_layout,
                                   RelayoutFormat::Param::Mode mode,
                                   const int oc = 0, const int group = 1) {
    if (src_layout.ndim > 0) {
        RelayoutFormat::Param trans_param;
        trans_param.mode = mode;
        trans_param.oc = oc;
        trans_param.group = group;
        relayout->param() = trans_param;
        relayout->deduce_layout(src_layout, dst_layout);
    } else {
        dst_layout = src_layout;
    }
}
}  // namespace

void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout(
        const SizeArgs& args, TensorLayout& inner_src_layout,
        TensorLayout& inner_weight_layout, TensorLayout& inner_dst_layout,
        TensorLayout& inner_bias_layout, TensorLayout& inner_z_layout) const {
    auto relayout_src = args.handle->create_operator<RelayoutFormat>();
    deduce_reformat_layout(relayout_src, *args.src_layout, inner_src_layout,
                           RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
                           args.filter_meta.group);
    deduce_reformat_layout(relayout_src, *args.filter_layout,
                           inner_weight_layout,
                           RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT);
    bool dst_float = args.dst_layout->dtype.enumv() == DTypeEnum::Float32;
    if (dst_float) {
        // Float output: the inner conv writes NCHW directly, so dst/bias/z
        // keep their original layouts.
        inner_dst_layout = *args.dst_layout;
        inner_bias_layout = *args.bias_layout;
        inner_z_layout = *args.z_layout;
    } else {
        deduce_reformat_layout(relayout_src, *args.dst_layout,
                               inner_dst_layout,
                               RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
                               args.filter_meta.group);
        deduce_reformat_layout(relayout_src, *args.bias_layout,
                               inner_bias_layout,
                               RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
                               args.filter_meta.group);
        deduce_reformat_layout(relayout_src, *args.z_layout, inner_z_layout,
                               RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
                               args.filter_meta.group);
    }
}

std::vector<Algorithm::SearchItem>
ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list(
        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
    const ConvBiasForwardImpl* o =
            static_cast<const ConvBiasForwardImpl*>(opr);
    SizeArgs args(const_cast<ConvBiasForwardImpl*>(o), layouts[0], layouts[1],
                  layouts[2], layouts[3], layouts[4], nullptr);
    TensorLayout inner_src_layout;
    TensorLayout inner_weight_layout;
    TensorLayout inner_dst_layout;
    TensorLayout inner_bias_layout;
    TensorLayout inner_z_layout;
    make_inner_layout(args, inner_src_layout, inner_weight_layout,
                      inner_dst_layout, inner_bias_layout, inner_z_layout);
    Param inner_conv_param = o->param();
    inner_conv_param.format = Param::Format::NCHW4;
    std::string param_str;
    Algorithm::serialize_write_pod(inner_conv_param, param_str);
    return {{Algorithm::OprType::CONVBIAS_FORWARD, param_str,
             {inner_src_layout, inner_weight_layout, inner_bias_layout,
              inner_z_layout, inner_dst_layout}}};
}

bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
        const SizeArgs& args) const {
    if (!args.src_layout->is_contiguous() ||
        !args.dst_layout->is_contiguous()) {
        return false;
    }
    auto&& param = args.opr->param();
    bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
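    // This fallback relies on cuDNN support for NCHW4 int8 convolution,
    // which requires cuDNN 7.5 or newer (version encoded as 7500).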
    bool is_version_ok = CUDNN_VERSION >= 7500;
    bool is_dtype_ok =
            args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
    // Bias, when present, must be a per-channel (1 x C x 1 x 1) tensor.
    bool is_bias_ok =
            args.bias_layout->ndim == 0 ||
            (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 &&
             args.bias_layout->shape[2] == 1 &&
             args.bias_layout->shape[3] == 1);
    bool is_ok = is_format_ok && is_version_ok && is_dtype_ok && is_bias_ok;
    return is_ok;
}

WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    TensorLayout inner_src_layout;
    TensorLayout inner_weight_layout;
    TensorLayout inner_dst_layout;
    TensorLayout inner_bias_layout;
    TensorLayout inner_z_layout;
    make_inner_layout(args, inner_src_layout, inner_weight_layout,
                      inner_dst_layout, inner_bias_layout, inner_z_layout);
    Param inner_conv_param = args.opr->param();
    size_t ws_dst = 0, ws_bias = 0, ws_z = 0;
    if (args.dst_layout->dtype.enumv() == DTypeEnum::Float32) {
        inner_conv_param.format = Param::Format::NCHW4_NCHW;
    } else {
        // Quantized output needs NCHW4 staging buffers for dst/bias/z.
        inner_conv_param.format = Param::Format::NCHW4;
        ws_dst = inner_dst_layout.span().dist_byte();
        ws_bias = inner_bias_layout.span().dist_byte();
        ws_z = inner_z_layout.span().dist_byte();
    }
    auto opr = args.handle->create_operator<ConvBiasForward>();
    opr->param() = inner_conv_param;
    set_execution_policy(args.opr, opr.get());
    return WorkspaceBundle(
            ptr, {inner_src_layout.span().dist_byte(),
                  inner_weight_layout.span().dist_byte(), ws_dst, ws_bias,
                  ws_z,
                  opr->get_workspace_in_bytes(
                          inner_src_layout, inner_weight_layout,
                          inner_bias_layout, inner_z_layout,
                          inner_dst_layout, nullptr)});
}

size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_in_bytes(
        const SizeArgs& args) const {
    auto trans_bundle = get_workspace_bundle(nullptr, args);
    return trans_bundle.total_size_in_bytes();
}

void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec(
        const ExecArgs& args) const {
    // src/bias/z: NCHW -> NCHW4.
    auto relayout_nchw_nchw4 = args.handle->create_operator<RelayoutFormat>();
    RelayoutFormat::Param in_trans;
    in_trans.mode = RelayoutFormat::Param::Mode::NCHW_NCHW4;
    in_trans.group = args.filter_meta.group;
    relayout_nchw_nchw4->param() = in_trans;

    // filter: NCHW -> NCHW4 weight layout.
    auto relayout_weight = args.handle->create_operator<RelayoutFormat>();
    RelayoutFormat::Param weight_trans;
    weight_trans.mode = RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT;
    relayout_weight->param() = weight_trans;

    // dst: NCHW4 -> NCHW (only used on the quantized-output path).
    auto relayout_nchw4_nchw = args.handle->create_operator<RelayoutFormat>();
    RelayoutFormat::Param nchw4_nchw_trans;
    nchw4_nchw_trans.mode = RelayoutFormat::Param::Mode::NCHW4_NCHW;
    nchw4_nchw_trans.oc = args.dst_layout->shape[1];
    nchw4_nchw_trans.group = args.filter_meta.group;
    relayout_nchw4_nchw->param() = nchw4_nchw_trans;

    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    TensorLayout inner_src_layout;
    TensorLayout inner_weight_layout;
    TensorLayout inner_dst_layout;
    TensorLayout inner_bias_layout;
    TensorLayout inner_z_layout;
    make_inner_layout(args, inner_src_layout, inner_weight_layout,
                      inner_dst_layout, inner_bias_layout, inner_z_layout);
    TensorND inner_src(bundle.get(0), inner_src_layout);
    TensorND inner_weight(bundle.get(1), inner_weight_layout);
    TensorND inner_dst(bundle.get(2), inner_dst_layout);
    TensorND inner_bias(bundle.get(3), inner_bias_layout);
    TensorND inner_z(bundle.get(4), inner_z_layout);
    bool dst_float = args.dst_layout->dtype.enumv() == DTypeEnum::Float32;
    Param inner_conv_param = args.opr->param();
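    // With float output the inner conv emits NCHW directly (NCHW4_NCHW);
    // with quantized output it stays in NCHW4 and the result is relayouted
    // back to NCHW after the conv (see relayout_nchw4_nchw above).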
    inner_conv_param.format =
            dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4;
    auto inner_opr = args.handle->create_operator<ConvBiasForward>();
    set_execution_policy(args.opr, inner_opr.get());
    inner_opr->param() = inner_conv_param;

    relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {});
    relayout_weight->exec(*args.filter_tensor, inner_weight, {});
    if (dst_float) {
        inner_opr->exec(inner_src, inner_weight, *args.bias_tensor,
                        *args.z_tensor, *args.dst_tensor, nullptr,
                        Workspace((dt_byte*)bundle.get(5),
                                  bundle.get_size(5)));
    } else {
        if (inner_bias_layout.ndim > 0) {
            relayout_nchw_nchw4->exec(*args.bias_tensor, inner_bias, {});
        }
        if (inner_z_layout.ndim > 0) {
            relayout_nchw_nchw4->exec(*args.z_tensor, inner_z, {});
        }
        inner_opr->exec(inner_src, inner_weight, inner_bias, inner_z,
                        inner_dst, nullptr,
                        Workspace((dt_byte*)bundle.get(5),
                                  bundle.get_size(5)));
        relayout_nchw4_nchw->exec(inner_dst, *args.dst_tensor, {});
    }
}

// vim: syntax=cpp.doxygen