提交 5e07e1e0 编写于 作者: M Megvii Engine Team

fix(dnn/falback): let cpu be able to execute int4 model

GitOrigin-RevId: 1a6b78f3b695aac0b9de346163e4b0f3cd1dc3fb
上级 0c098a8c
......@@ -431,7 +431,9 @@ ConvolutionImpl::AlgoDataType ConvolutionImpl::NCBKernSizeParam::deduce_algo_dat
}
} else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
} else if (src_type.enumv() == DTypeEnum::QuantizedS4) {
} else if (
src_type.enumv() == DTypeEnum::QuantizedS4 ||
src_type.enumv() == DTypeEnum::Quantized4Asymm) {
return ConvolutionImpl::AlgoDataType::QINT4x4x32;
} else {
megdnn_throw(ssprintf(
......@@ -477,7 +479,8 @@ void ConvolutionBackwardDataImpl::exec(
_megdnn_workspace workspace) {
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
(param().format == param::Convolution::Format::NCHW &&
((param().format == param::Convolution::Format::NCHW ||
param().format == param::Convolution::Format::NHWC) &&
grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) {
return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace);
}
......@@ -499,7 +502,8 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
(param().format == param::Convolution::Format::NCHW &&
((param().format == param::Convolution::Format::NCHW ||
param().format == param::Convolution::Format::NHWC) &&
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
filter, diff, grad);
......@@ -514,7 +518,8 @@ std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl
const TensorLayout& grad) {
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
(param().format == param::Convolution::Format::NCHW &&
((param().format == param::Convolution::Format::NCHW ||
param().format == param::Convolution::Format::NHWC) &&
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
filter, diff, grad);
......@@ -541,7 +546,8 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
const AlgoAttribute& negative_attr) {
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
(param().format == param::Convolution::Format::NCHW &&
((param().format == param::Convolution::Format::NCHW ||
param().format == param::Convolution::Format::NHWC) &&
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
filter, diff, grad, workspace_limit_in_bytes, positive_attr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册