You need to sign in or sign up before continuing.
提交 7b2a76d1 编写于 作者: M Megvii Engine Team 提交者: huangxinda

refactor(mgb): make conv handle noncontiguous tensors

GitOrigin-RevId: 86282709b3013ae44e13f74639a1c45a0dad97b0
上级 ca2828dd
...@@ -511,6 +511,12 @@ protected: ...@@ -511,6 +511,12 @@ protected:
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes, const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter); const PreprocessedFilter* preprocessed_filter);
CanonizedFilterMeta check_exec_allow_noncontiguous(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter);
}; };
using ConvBias = ConvBiasForward; using ConvBias = ConvBiasForward;
......
...@@ -11,30 +11,18 @@ ...@@ -11,30 +11,18 @@
*/ */
#include "src/common/conv_bias.h" #include "src/common/conv_bias.h"
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
namespace megdnn { namespace megdnn {
namespace {
void ConvBiasForward::deduce_dtype(DType src, DType filter, DType /* bias */, void do_check_exec_common(
DType /* z */, DType& dst) { ConvBiasForward* opr, const TensorLayout& src,
check_or_deduce_dtype_fwd(src, filter, dst); const TensorLayout& filter, const TensorLayout& bias,
} const TensorLayout& z, const TensorLayout& dst,
size_t workspace_in_bytes,
void ConvBiasForward::deduce_layout(const TensorLayout& src, const ConvBiasForward::PreprocessedFilter* preprocessed_filter) {
const TensorLayout& filter,
const TensorLayout& /* bias */,
const TensorLayout& /* z */,
TensorLayout& dst) {
deduce_layout_fwd(src, filter, dst);
}
ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter) {
megdnn_assert((src.dtype.enumv() == filter.dtype.enumv()) || megdnn_assert((src.dtype.enumv() == filter.dtype.enumv()) ||
(src.dtype.enumv() == DTypeEnum::Quantized4Asymm && (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
filter.dtype.enumv() == DTypeEnum::QuantizedS4)); filter.dtype.enumv() == DTypeEnum::QuantizedS4));
...@@ -52,9 +40,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -52,9 +40,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
} }
} }
auto ret = check_layout_fwd(src, filter, dst);
megdnn_assert_contiguous(bias); megdnn_assert_contiguous(bias);
auto required_workspace_in_bytes = get_workspace_in_bytes( auto required_workspace_in_bytes = opr->get_workspace_in_bytes(
src, filter, bias, z, dst, preprocessed_filter); src, filter, bias, z, dst, preprocessed_filter);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes,
"worksapce have size of %zu, but need %zu", "worksapce have size of %zu, but need %zu",
...@@ -68,55 +55,58 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -68,55 +55,58 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
return bias.eq_layout(dst); return bias.eq_layout(dst);
} }
}; };
if (check_eq(bias, dst)) if (check_eq(bias, dst)) {
return ret; return;
if (param().format == param::ConvBias::Format::NCHW || }
param().format == param::ConvBias::Format::NCHW4_NCHW) { if (opr->param().format == param::ConvBias::Format::NCHW ||
opr->param().format == param::ConvBias::Format::NCHW4_NCHW) {
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[3] == 1);
} else if (param().format == param::ConvBias::Format::NHWC) { } else if (opr->param().format == param::ConvBias::Format::NHWC) {
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == 1); megdnn_assert(bias.shape[1] == 1);
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s", megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
} else if (param().format == param::ConvBias::Format::NCHW4 || } else if (opr->param().format == param::ConvBias::Format::NCHW4 ||
param().format == param::ConvBias::Format::NCHW44 || opr->param().format == param::ConvBias::Format::NCHW44 ||
param().format == param::ConvBias::Format::NCHW44_DOT || opr->param().format == param::ConvBias::Format::NCHW44_DOT ||
param().format == param::ConvBias::Format::NCHW32_NCHW4) { opr->param().format ==
param::ConvBias::Format::NCHW32_NCHW4) {
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[3] == 1);
megdnn_assert(bias.shape[4] == 4); megdnn_assert(bias.shape[4] == 4);
} else if (param().format == param::ConvBias::Format::NCHW8 || } else if (opr->param().format == param::ConvBias::Format::NCHW8 ||
param().format == param::ConvBias::Format::NCHW88 ) { opr->param().format == param::ConvBias::Format::NCHW88) {
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[3] == 1);
megdnn_assert(bias.shape[4] == 8); megdnn_assert(bias.shape[4] == 8);
} else if (param().format == param::ConvBias::Format::NCHW32 || } else if (opr->param().format == param::ConvBias::Format::NCHW32 ||
param().format == param::ConvBias::Format::NCHW4_NCHW32) { opr->param().format ==
param::ConvBias::Format::NCHW4_NCHW32) {
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[3] == 1);
megdnn_assert(bias.shape[4] == 32); megdnn_assert(bias.shape[4] == 32);
} else if (param().format == param::ConvBias::Format::CHWN4) { } else if (opr->param().format == param::ConvBias::Format::CHWN4) {
megdnn_assert(bias.shape[0] == dst.shape[0], "bias:%s, dst:%s", megdnn_assert(bias.shape[0] == dst.shape[0], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
megdnn_assert(bias.shape[1] == 1); megdnn_assert(bias.shape[1] == 1);
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[3] == 1);
megdnn_assert(bias.shape[4] == 4); megdnn_assert(bias.shape[4] == 4);
} else if (param().format == param::ConvBias::Format::NCHW64) { } else if (opr->param().format == param::ConvBias::Format::NCHW64) {
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
...@@ -124,7 +114,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -124,7 +114,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[3] == 1);
megdnn_assert(bias.shape[4] == 64); megdnn_assert(bias.shape[4] == 64);
} else { } else {
megdnn_assert(param().format == param::ConvBias::Format::NHWCD4); megdnn_assert(opr->param().format ==
param::ConvBias::Format::NHWCD4);
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == 1); megdnn_assert(bias.shape[1] == 1);
megdnn_assert(bias.shape[2] == dst.shape[2], "bias:%s, dst:%s", megdnn_assert(bias.shape[2] == dst.shape[2], "bias:%s, dst:%s",
...@@ -135,11 +126,53 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -135,11 +126,53 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
} }
if (z.ndim != 0) { if (z.ndim != 0) {
megdnn_assert(param().format != param::ConvBias::Format::NCHW4_NCHW32); megdnn_assert(opr->param().format !=
megdnn_assert(param().format != param::ConvBias::Format::NCHW32_NCHW4); param::ConvBias::Format::NCHW4_NCHW32);
megdnn_assert(opr->param().format !=
param::ConvBias::Format::NCHW32_NCHW4);
megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); megdnn_assert(z.dtype.enumv() == dst.dtype.enumv());
megdnn_assert(z.eq_shape(dst)); megdnn_assert(z.eq_shape(dst));
} }
}
} // namespace
void ConvBiasForward::deduce_dtype(DType src, DType filter, DType /* bias */,
DType /* z */, DType& dst) {
check_or_deduce_dtype_fwd(src, filter, dst);
}
void ConvBiasForward::deduce_layout(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& /* bias */,
const TensorLayout& /* z */,
TensorLayout& dst) {
deduce_layout_fwd(src, filter, dst);
}
ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter) {
do_check_exec_common(this, src, filter, bias, z, dst, workspace_in_bytes,
preprocessed_filter);
auto ret = check_layout_fwd(src, filter, dst);
return ret;
}
ConvBiasForward::CanonizedFilterMeta
ConvBiasForward::check_exec_allow_noncontiguous(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter) {
do_check_exec_common(this, src, filter, bias, z, dst, workspace_in_bytes,
preprocessed_filter);
TensorLayout dst_expected;
dst_expected.dtype = dst.dtype;
auto ret = deduce_layout_fwd(src, filter, dst_expected);
megdnn_assert_eq_shape(dst_expected, dst);
return ret; return ret;
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "megdnn/handle.h" #include "megdnn/handle.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/general.h" #include "megdnn/oprs/general.h"
#include "megdnn/oprs/nn.h"
#include "megdnn/oprs/nn_int.h" #include "megdnn/oprs/nn_int.h"
#include "src/common/utils.h" #include "src/common/utils.h"
......
...@@ -595,8 +595,6 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, ...@@ -595,8 +595,6 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
TensorLayout& dst) const { TensorLayout& dst) const {
auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); }; auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
MEGDNN_MARK_USED_VAR(errmsg); MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(filter);
megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str()); megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
megdnn_assert(((src.dtype.enumv() == filter.dtype.enumv()) || megdnn_assert(((src.dtype.enumv() == filter.dtype.enumv()) ||
(src.dtype.enumv() == DTypeEnum::Quantized4Asymm && (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
...@@ -976,6 +974,8 @@ ConvolutionBase<param::Convolution>::CanonizedFilterMeta ...@@ -976,6 +974,8 @@ ConvolutionBase<param::Convolution>::CanonizedFilterMeta
ConvolutionBase<param::Convolution>::check_layout_fwd( ConvolutionBase<param::Convolution>::check_layout_fwd(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) const { const TensorLayout& dst) const {
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(filter);
TensorLayout dst_expected; TensorLayout dst_expected;
dst_expected.dtype = dst.dtype; dst_expected.dtype = dst.dtype;
...@@ -989,6 +989,8 @@ ConvolutionBase<param::ConvBias>::CanonizedFilterMeta ...@@ -989,6 +989,8 @@ ConvolutionBase<param::ConvBias>::CanonizedFilterMeta
ConvolutionBase<param::ConvBias>::check_layout_fwd( ConvolutionBase<param::ConvBias>::check_layout_fwd(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) const { const TensorLayout& dst) const {
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(filter);
TensorLayout dst_expected; TensorLayout dst_expected;
dst_expected.dtype = dst.dtype; dst_expected.dtype = dst.dtype;
...@@ -1002,6 +1004,8 @@ ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta ...@@ -1002,6 +1004,8 @@ ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta
ConvolutionBase<param::BatchConvBias>::check_layout_fwd( ConvolutionBase<param::BatchConvBias>::check_layout_fwd(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) const { const TensorLayout& dst) const {
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(filter);
TensorLayout dst_expected; TensorLayout dst_expected;
dst_expected.dtype = dst.dtype; dst_expected.dtype = dst.dtype;
......
...@@ -116,8 +116,9 @@ ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( ...@@ -116,8 +116,9 @@ ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& filter, const TensorLayout& bias,
const TensorLayout& z, const TensorLayout& dst, const TensorLayout& z, const TensorLayout& dst,
const PreprocessedFilter* preprocessed_filter) const PreprocessedFilter* preprocessed_filter)
: SizeArgs(o, src, filter, o->check_layout_fwd(src, filter, dst), bias, : SizeArgs(o, src, filter,
z, dst, preprocessed_filter) {} o->make_canonized_filter_meta(src.ndim, filter), bias, z,
dst, preprocessed_filter) {}
ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
ConvBiasForwardImpl* o, const TensorLayout& src, ConvBiasForwardImpl* o, const TensorLayout& src,
......
...@@ -75,8 +75,8 @@ ConvBiasForwardImpl::AlgoBatchedMatmul::get_subopr_list( ...@@ -75,8 +75,8 @@ ConvBiasForwardImpl::AlgoBatchedMatmul::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const { const TensorLayoutArray& layouts, const OperatorBase* opr) const {
const ConvBiasForwardImpl* conv_bias_opr = const ConvBiasForwardImpl* conv_bias_opr =
static_cast<const ConvBiasForwardImpl*>(opr); static_cast<const ConvBiasForwardImpl*>(opr);
CanonizedFilterMeta fm = CanonizedFilterMeta fm = conv_bias_opr->make_canonized_filter_meta(
conv_bias_opr->check_layout_fwd(layouts[0], layouts[1], layouts[4]); layouts[0].ndim, layouts[1]);
auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[4], auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[4],
conv_bias_opr); conv_bias_opr);
......
...@@ -20,6 +20,10 @@ using namespace conv_bias; ...@@ -20,6 +20,10 @@ using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoChanwise::is_available( bool ConvBiasForwardImpl::AlgoChanwise::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.src_layout->dtype == args.filter_layout->dtype && if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) { args.src_layout->dtype == dtype::BFloat16()) {
return false; return false;
......
...@@ -21,6 +21,10 @@ using namespace conv_bias; ...@@ -21,6 +21,10 @@ using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoChanwise8x8x32::is_available( bool ConvBiasForwardImpl::AlgoChanwise8x8x32::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.z_layout->ndim > 0) if (args.z_layout->ndim > 0)
return false; return false;
using NonlineMode = param::ConvBias::NonlineMode; using NonlineMode = param::ConvBias::NonlineMode;
......
...@@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) { ...@@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) {
bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available( bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.src_layout->dtype == args.filter_layout->dtype && if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) { args.src_layout->dtype == dtype::BFloat16()) {
return false; return false;
......
...@@ -63,6 +63,10 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout( ...@@ -63,6 +63,10 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout(
bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available( bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
auto&& param = args.opr->param(); auto&& param = args.opr->param();
bool is_format_ok = param.format == param::ConvBias::Format::NCHW; bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
bool is_version_ok = CUDNN_VERSION >= 7500; bool is_version_ok = CUDNN_VERSION >= 7500;
......
...@@ -24,6 +24,10 @@ using namespace conv_bias; ...@@ -24,6 +24,10 @@ using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) && args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4) args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
......
...@@ -74,6 +74,10 @@ void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, ...@@ -74,6 +74,10 @@ void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter,
bool ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::is_available( bool ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.bias_layout->ndim <= 0) if (args.bias_layout->ndim <= 0)
return false; return false;
......
...@@ -62,6 +62,10 @@ void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, ...@@ -62,6 +62,10 @@ void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter,
bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::is_available( bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.bias_layout->ndim <= 0) if (args.bias_layout->ndim <= 0)
return false; return false;
......
...@@ -109,6 +109,10 @@ INST(PerChannelBiasVisitor); ...@@ -109,6 +109,10 @@ INST(PerChannelBiasVisitor);
bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter:: bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::
is_available(const SizeArgs& args) const { is_available(const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.bias_layout->ndim <= 0) if (args.bias_layout->ndim <= 0)
return false; return false;
......
...@@ -109,6 +109,10 @@ INST(PerChannelBiasVisitor); ...@@ -109,6 +109,10 @@ INST(PerChannelBiasVisitor);
bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth:: bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::
is_available(const SizeArgs& args) const { is_available(const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.bias_layout->ndim <= 0) if (args.bias_layout->ndim <= 0)
return false; return false;
......
...@@ -23,6 +23,10 @@ using namespace convolution; ...@@ -23,6 +23,10 @@ using namespace convolution;
#if CUDA_VERSION >= 10020 #if CUDA_VERSION >= 10020
bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.bias_layout->ndim <= 0) if (args.bias_layout->ndim <= 0)
return false; return false;
......
...@@ -20,6 +20,10 @@ using namespace cuda; ...@@ -20,6 +20,10 @@ using namespace cuda;
bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.bias_layout->ndim <= 0) if (args.bias_layout->ndim <= 0)
return false; return false;
......
...@@ -20,6 +20,10 @@ using namespace cuda; ...@@ -20,6 +20,10 @@ using namespace cuda;
#if CUDA_VERSION >= 10000 #if CUDA_VERSION >= 10000
bool ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::is_available( bool ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.bias_layout->ndim <= 0) if (args.bias_layout->ndim <= 0)
return false; return false;
......
...@@ -61,8 +61,8 @@ ConvBiasForwardImpl::AlgoMatmul::get_subopr_list( ...@@ -61,8 +61,8 @@ ConvBiasForwardImpl::AlgoMatmul::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const { const TensorLayoutArray& layouts, const OperatorBase* opr) const {
const ConvBiasForwardImpl* conv_bias_opr = const ConvBiasForwardImpl* conv_bias_opr =
static_cast<const ConvBiasForwardImpl*>(opr); static_cast<const ConvBiasForwardImpl*>(opr);
CanonizedFilterMeta fm = CanonizedFilterMeta fm = conv_bias_opr->make_canonized_filter_meta(
conv_bias_opr->check_layout_fwd(layouts[0], layouts[1], layouts[4]); layouts[0].ndim, layouts[1]);
auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[4], auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[4],
conv_bias_opr); conv_bias_opr);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"
#include "src/common/conv_bias.h"
#include "src/common/algo_chooser.h" #include "src/common/algo_chooser.h"
#include "src/cuda/cudnn_with_check.h" #include "src/cuda/cudnn_with_check.h"
...@@ -28,8 +29,9 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, ...@@ -28,8 +29,9 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter, const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(src.layout, filter.layout, bias.layout, z.layout, dst.layout, check_exec_allow_noncontiguous(src.layout, filter.layout, bias.layout,
workspace.size, preprocessed_filter); z.layout, dst.layout, workspace.size,
preprocessed_filter);
AlgoBase::ExecArgs args(this, src, filter, bias, z, dst, workspace, AlgoBase::ExecArgs args(this, src, filter, bias, z, dst, workspace,
preprocessed_filter); preprocessed_filter);
auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout, auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout,
......
...@@ -87,6 +87,7 @@ public: ...@@ -87,6 +87,7 @@ public:
const AlgoAttribute& negative_attr) override; const AlgoAttribute& negative_attr) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
......
...@@ -25,6 +25,10 @@ using namespace activation_u4; ...@@ -25,6 +25,10 @@ using namespace activation_u4;
#if CUDA_VERSION >= 10000 #if CUDA_VERSION >= 10000
bool ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::is_available( bool ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
if (args.z_layout->ndim > 0) if (args.z_layout->ndim > 0)
return false; return false;
......
...@@ -233,9 +233,9 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, ...@@ -233,9 +233,9 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
dt_byte* workspace_ptr = workspace.raw_ptr; dt_byte* workspace_ptr = workspace.raw_ptr;
// ============================w * f + b================================ // ============================w * f + b================================
auto filter_meta = auto filter_meta = check_exec_allow_noncontiguous(
check_exec(src.layout, filter.layout, bias.layout, z.layout, src.layout, filter.layout, bias.layout, z.layout, dst.layout,
dst.layout, workspace.size, preprocessed_filter); workspace.size, preprocessed_filter);
auto sfb = dst; auto sfb = dst;
if (bias.layout.dtype.enumv() != dst.layout.dtype.enumv()) { if (bias.layout.dtype.enumv() != dst.layout.dtype.enumv()) {
// intermediate result // intermediate result
......
...@@ -749,6 +749,18 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_CUDNN_CONVOLUTION) { ...@@ -749,6 +749,18 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_CUDNN_CONVOLUTION) {
.set_param(arg.param) .set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}}); .execs({arg.src, arg.filter, arg.bias, {}, {}});
} }
//! noncontiguous case
{
param::ConvBias param;
param.pad_h = param.pad_w = 1;
checker.set_param(param).execl(TensorLayoutArray{
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()},
{{}, {}, dtype::Float32()},
{{}, {}, dtype::Float32()},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
});
}
} }
TEST_F(CUDA, CONV_BIAS_FORWARD_INPLACE_MATMUL) { TEST_F(CUDA, CONV_BIAS_FORWARD_INPLACE_MATMUL) {
...@@ -791,6 +803,18 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_INPLACE_MATMUL) { ...@@ -791,6 +803,18 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_INPLACE_MATMUL) {
.execs({{2, 3, 3, 16}, {5, 3, 3, 3}, {1, 5, 1, 1}, {}, {}}) .execs({{2, 3, 3, 16}, {5, 3, 3, 3}, {1, 5, 1, 1}, {}, {}})
.execs({{2, 2, 8, 3}, {3, 2, 3, 3}, {1, 3, 1, 1}, {}, {}}); .execs({{2, 2, 8, 3}, {3, 2, 3, 3}, {1, 3, 1, 1}, {}, {}});
} }
//! noncontiguous case
{
param::ConvBias param;
param.pad_h = param.pad_w = 1;
checker.set_param(param).execl(TensorLayoutArray{
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()},
{{}, {}, dtype::Float32()},
{{}, {}, dtype::Float32()},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
});
}
} }
TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL) { TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL) {
...@@ -835,6 +859,18 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL) { ...@@ -835,6 +859,18 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL) {
.execs({{2, 3, 3, 16}, {5, 3, 3, 3}, {1, 5, 1, 1}, {}, {}}) .execs({{2, 3, 3, 16}, {5, 3, 3, 3}, {1, 5, 1, 1}, {}, {}})
.execs({{2, 2, 8, 3}, {3, 2, 3, 3}, {1, 3, 1, 1}, {}, {}}); .execs({{2, 2, 8, 3}, {3, 2, 3, 3}, {1, 3, 1, 1}, {}, {}});
} }
//! noncontiguous case
{
param::ConvBias param;
param.pad_h = param.pad_w = 1;
checker.set_param(param).execl(TensorLayoutArray{
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()},
{{}, {}, dtype::Float32()},
{{}, {}, dtype::Float32()},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
});
}
} }
TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_8x8x32) { TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_8x8x32) {
...@@ -880,6 +916,21 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_8x8x32) { ...@@ -880,6 +916,21 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_8x8x32) {
.execs({{2, 3, 16, 3}, {5, 3, 3, 3}, {1, 1, 1, 5}, {}, {}}) .execs({{2, 3, 16, 3}, {5, 3, 3, 3}, {1, 1, 1, 5}, {}, {}})
.execs({{2, 8, 3, 2}, {3, 3, 3, 2}, {1, 1, 1, 3}, {}, {}}); .execs({{2, 8, 3, 2}, {3, 3, 3, 2}, {1, 1, 1, 3}, {}, {}});
} }
//! noncontiguous case
{
param::ConvBias param;
param.pad_h = param.pad_w = 1;
param.format = param::ConvBias::Format::NHWC;
checker.set_param(param).execl(TensorLayoutArray{
{{2, 7, 7, 16}, {1568, 224, 32, 1}, dtype::QuantizedS8{1.2f}},
{{16, 3, 3, 16}, {144, 48, 16, 1}, dtype::QuantizedS8{1.3f}},
{{}, {}, dtype::QuantizedS32{1.2f * 1.3f}},
{{}, {}, dtype::QuantizedS8{1.1f}},
{{2, 7, 7, 16},
{1568, 224, 32, 1},
dtype::QuantizedS32{1.2f * 1.3f}},
});
}
} }
TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_NCHW4) { TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_NCHW4) {
...@@ -913,6 +964,21 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_NCHW4) { ...@@ -913,6 +964,21 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_NCHW4) {
checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 4, 1, 1, 4}, {}, {}}); checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 4, 1, 1, 4}, {}, {}});
checker.exec( checker.exec(
{{8, 64, 12, 12, 4}, {256, 64, 3, 3, 4}, {1, 64, 1, 1, 4}, {}, {}}); {{8, 64, 12, 12, 4}, {256, 64, 3, 3, 4}, {1, 64, 1, 1, 4}, {}, {}});
//! noncontiguous case
{
param::ConvBias param;
param.pad_h = param.pad_w = 1;
param.format = ConvBias::Param::Format::NCHW4;
checker.set_param(param).execl(TensorLayoutArray{
{{2, 4, 7, 7, 4}, {1568, 196, 28, 4, 1}, dtype::QuantizedS8{1.2f}},
{{16, 4, 3, 3, 4}, {144, 36, 12, 4, 1}, dtype::QuantizedS8{1.3f}},
{{}, {}, dtype::QuantizedS32{1.2f * 1.3f}},
{{}, {}, dtype::QuantizedS8{1.1f}},
{{2, 4, 7, 7, 4},
{1568, 196, 28, 4, 1},
dtype::QuantizedS32{1.2f * 1.3f}},
});
}
} }
TEST_F(CUDA, CONV_BIAS_FORWARD_BATCHED_MATMUL) { TEST_F(CUDA, CONV_BIAS_FORWARD_BATCHED_MATMUL) {
...@@ -939,6 +1005,17 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_BATCHED_MATMUL) { ...@@ -939,6 +1005,17 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_BATCHED_MATMUL) {
checker.set_param(arg.param); checker.set_param(arg.param);
checker.execs({arg.src, arg.filter, arg.bias, {}, {}}); checker.execs({arg.src, arg.filter, arg.bias, {}, {}});
} }
//! noncontiguous case
{
param::ConvBias param;
checker.set_param(param).execl(TensorLayoutArray{
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
{{16, 16, 1, 1}, {16, 1, 1, 1}, dtype::Float32()},
{{}, {}, dtype::Float32()},
{{}, {}, dtype::Float32()},
{{2, 16, 7, 7}, {784, 49, 7, 1}, dtype::Float32()},
});
}
} }
TEST_F(CUDA, CONV_BIAS_FORWARD_GROUP) { TEST_F(CUDA, CONV_BIAS_FORWARD_GROUP) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册