Commit ba2f0c2e authored by Megvii Engine Team

fix(dnn/cuda): fix cudnn_conv algo of conv_bias opr for fp16 add z cases

GitOrigin-RevId: b29b009de0803d0d8c4c6f1995ed10f2920652a8
Parent 30976c23
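In short, AlgoCUDNNConv used to reject any call that carried a z tensor; the diff below removes that check and handles z explicitly: when z's dtype differs from the bias dtype (the fp16 add-z case), z is first converted, in workspace, to the deduced convolution output dtype and then added elementwise onto the convolution result before bias and nonlinearity are applied. A minimal sketch of that sequence, using placeholder names (handle, z, z_converted, conv_dst) rather than the actual locals in the diff:

    // hedged sketch only -- the real code below additionally guards on the dtype
    // mismatch and deduces the output dtype via check_or_deduce_dtype_fwd()
    auto typecvt = handle->create_operator<TypeCvt>();
    typecvt->exec(z, z_converted);                 // z_converted is backed by workspace
    auto add = handle->create_operator<ElemwiseForward>();
    add->param().mode = Elemwise::Param::Mode::ADD;
    add->exec({conv_dst, z_converted}, conv_dst);  // dst += z; bias/activation follow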
@@ -19,9 +19,6 @@ using namespace cuda;
using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(const SizeArgs& args) const {
    if (args.z_layout->ndim > 0)
        return false;
    if (args.filter_meta.format != Param::Format::NCHW &&
        args.filter_meta.format != Param::Format::NHWC) {
        if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
@@ -75,6 +72,15 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
        sizes.push_back(dst_layout.span().dist_byte());
    }
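    // if z is given and its dtype differs from the bias dtype, reserve one more
    // workspace chunk, sized for z converted to the deduced conv output dtype
    // (it is filled by a TypeCvt in exec()).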
    if (args.z_layout->ndim > 0 &&
        args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
        auto z_layout = *args.z_layout;
        z_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(
                args.src_layout->dtype, args.filter_layout->dtype, z_layout.dtype);
        sizes.push_back(z_layout.span().dist_byte());
    }

    SizeArgs conv_args = args;
    conv_args.dst_layout = &dst_layout;
@@ -129,6 +135,22 @@ void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const {
                cudnnGetErrorString(status), conv_args.to_string().c_str());
    }
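    // add-z handling for the plain cuDNN convolution algorithm: if z's dtype
    // differs from the bias dtype, convert z (into the workspace chunk reserved
    // in get_workspace_bundle) to the conv output dtype, then add it elementwise
    // onto the convolution result.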
    if (args.z_layout->ndim > 0) {
        auto z_tensor = *args.z_tensor;
        if (args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
            z_tensor.raw_ptr = bundle.get(2);
            z_tensor.layout.dtype = DType();
            args.opr->check_or_deduce_dtype_fwd(
                    args.src_layout->dtype, args.filter_layout->dtype,
                    z_tensor.layout.dtype);
            auto typecvt = args.handle->create_operator<TypeCvt>();
            typecvt->exec(*args.z_tensor, z_tensor);
        }
        auto add = args.handle->create_operator<ElemwiseForward>();
        add->param().mode = Elemwise::Param::Mode::ADD;
        add->exec({conv_dst_tensor, z_tensor}, conv_dst_tensor);
    }

    handle_bias_and_nonlinear(
            args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
            args.bias_tensor);
......
@@ -71,11 +71,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
        return false;
    }
#if CUDNN_VERSION < 7605
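    // before cuDNN 7.6.5 the fused conv_bias_activation algorithm is rejected for
    // fp16 src/dst in every layout format (previously only NHWC was excluded here).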
    if (args.src_layout->dtype.enumv() == DTypeEnum::Float16 &&
        args.dst_layout->dtype.enumv() == DTypeEnum::Float16 &&
        param.format == param::ConvBias::Format::NHWC) {
        args.dst_layout->dtype.enumv() == DTypeEnum::Float16) {
        return false;
    }
#endif

#if CUDNN_MAJOR < 8
    if (m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
......
@@ -1293,6 +1293,56 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_TENSORCORE_INT8) {
    }
}
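// regression test for the fp16 add-z fix: force the CUDNN:Convolution algorithm and
// run conv_bias with a z input in both NCHW and NHWC, for fp32 and fp16.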
TEST_F(CUDA, CONV_BIAS_ADD_Z_CUDNN_CONVOLUTION) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle_cuda());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ConvBiasForward::algo_name<ConvBias::DefaultParam>("CUDNN:Convolution", {})
                    .c_str()));

    NormalRNG default_rng;
    param::ConvBias param;
    param.pad_h = param.pad_w = 1;
    using Format = param::ConvBias::Format;
    using NLMode = param::ConvBias::NonlineMode;
    param.nonlineMode = NLMode::RELU;

    auto c = [&](DType dt) {
        param.format = Format::NCHW;
        /// set epsilon to 2e-3 to bypass the low accuracy of the winograd algorithm
        float eps = 2e-3;
        if (dt == dtype::Float16()) {
            eps = 1e-2;
            param.compute_mode = param::ConvBias::ComputeMode::FLOAT32;
        }
        checker.set_dtype(0, dt)
                .set_dtype(1, dt)
                .set_dtype(2, dt)
                .set_dtype(3, dt)
                .set_dtype(4, dt)
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_rng(2, &default_rng)
                .set_rng(3, &default_rng)
                .set_epsilon(eps)
                .set_param(param)
                .execs({{16, 256, 7, 7},
                        {256, 256, 3, 3},
                        {1, 256, 1, 1},
                        {16, 256, 7, 7},
                        {}});
        param.format = Format::NHWC;
        checker.set_param(param).execs(
                {{16, 7, 7, 256},
                 {256, 3, 3, 256},
                 {1, 1, 1, 256},
                 {16, 7, 7, 256},
                 {}});
    };

    c(dtype::Float32());
    c(dtype::Float16());
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
    require_compute_capability(7, 5);
......