// convolution.cpp -- convolution operator parameter checking / layout deduction.
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

using namespace megdnn;

namespace {
template <typename Param>
M
Megvii Engine Team 已提交
8 9 10
std::string get_errmsg(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const Param& param) {
11 12 13 14
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(dst);
    return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
M
Megvii Engine Team 已提交
15
           megdnn_layout_msg(dst) + ", " + "is_nchw=" +
M
Megvii Engine Team 已提交
16 17 18 19
           std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
           "is_xcorr=" +
           std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " +
           "pad_h=" + std::to_string(param.pad_h) + ", " +
M
Megvii Engine Team 已提交
20 21 22 23 24
           "pad_w=" + std::to_string(param.pad_w) + ", " +
           "stride_h=" + std::to_string(param.stride_h) + ", " +
           "stride_w=" + std::to_string(param.stride_w) + ", " +
           "dilate_h=" + std::to_string(param.dilate_h) + ", " +
           "dilate_w=" + std::to_string(param.dilate_w);
25 26 27 28 29 30 31 32 33 34 35
}

//! Generic hook for extracting a spatial (fh/fw) filter dimension; the default
//! returns the raw value unchanged. Specializations for specific Param::Format
//! values may rescale it (none are visible in this chunk).
template <typename Param, typename Param::Format>
uint32_t spatial_getter(uint32_t filter, const Param&) {
    return filter;
}

//! Fill @p ret (group, ocpg, icpg, spatial dims, dilated spatial dims) from a
//! filter layout in NCHW or NHWC format, for dense or group convolution.
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nchw_nhwc(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    megdnn_assert(
            param.format == Param::Format::NCHW || param.format == Param::Format::NHWC);
    auto img_ndim = src_ndim - 2;
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);

        // grp, oc, ic, dims[]
        ret.group = filter[0];
        flt_start = 1;
    }

    uint32_t ic_block_size = 1, oc_block_size = 1;
    if (param.format == Param::Format::NCHW) {
        // filter should be (oc, ic, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
    } else {
        megdnn_assert(
                param.format == Param::Format::NHWC, "invalid conv tensor format");
        // filter should be (oc, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 3;
    }
    ret.spatial_ndim = src_ndim - 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
    ret.icpg = filter[flt_start + icpg_pos] * ic_block_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = spatial_getter<Param, Param::Format::NCHW>(
                filter[i + flt_start + flt_spatial_start], param);
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

//! Fill @p ret from a filter layout in NHWCD4 format (non-dot-product filters).
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N H IC/4 W 4
     * Filter:
     *        OC/4, FH, FW, IC, 4 [dense]
     *        GROUP, OC/4, FH, FW, IC, 4 [group]
     *        GROUP/4, 1, FH, FW, 4 [chanwise]
     */
    megdnn_assert(param.format == Param::Format::NHWCD4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // chanwise layout stores GROUP/4 in dim 0 with a dummy dim 1 of 1
        if (filter.ndim == img_ndim + 3 && filter[1] == 1) {
            is_chanwise = true;
            ret.group = filter[0] * 4;
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        ret.ocpg = filter[flt_start] * 4;
        ret.icpg = filter[flt_start + 3];
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

//! Fill @p ret from a filter layout in NHWCD4 format for dot-product (int8
//! quantized) filters; differs from the plain NHWCD4 variant in the extra
//! 4x4 inner blocks, hence icpg is scaled by 4 as well.
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4_dot(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N H IC/4 W 4
     * Filter:
     *        GROUP/4, 1, FH, FW, 4 [chanwise]
     *        OC/4, FH, FW, IC/4, 4, 4 [dense]
     *        GROUP, OC/4, FH, FW, IC/4, 4, 4 [group]
     */
    megdnn_assert(param.format == Param::Format::NHWCD4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        if (filter.ndim == img_ndim + 3) {
            // chanwise layout: dim 1 must be the dummy 1
            megdnn_assert(filter[1] == 1);
            is_chanwise = true;
            ret.group = filter[0] * 4;
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        ret.ocpg = filter[flt_start] * 4;
        ret.icpg = filter[flt_start + 3] * 4;
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

//! Fill @p ret from a filter layout in NCHW88 / NCHW44 / NCHW44_DOT format
//! (channel-blocked layouts with inner block of @p pack_size, possibly
//! 2*pack_size for hybrid blocks). Dilation must be 1 for these formats.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwxx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     *
     ** NCHW44-DOT mode
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
     * [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *                  FH, FW, pack_size(OC), pack_size(IC)} [group]
     *
     * NCHW88 and NCHW44 mode
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
     * [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *                  FH, FW, pack_size(IC), pack_size(OC)} [group]
     *        {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
     */

    megdnn_assert(
            param.format == Param::Format::NCHW88 ||
            param.format == Param::Format::NCHW44 ||
            param.format == Param::Format::NCHW44_DOT);
    size_t img_ndim = 2;
    size_t flt_start = 0;
    size_t flt_spatial_start = 2;
    size_t pack_c_size = 0;
    if (param.sparse == Param::Sparse::DENSE) {
        if (filter.ndim == img_ndim + 4) {
            // oihw8i8o case: the last two dims are the inner channel blocks;
            // they may also both be 2*pack_size (hybrid block).
            megdnn_assert(
                    (filter[filter.ndim - 2] == pack_size &&
                     filter[filter.ndim - 1] == pack_size) ||
                            (filter[filter.ndim - 2] == 2 * pack_size &&
                             filter[filter.ndim - 1] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
            ret.group = 1;
            flt_start = 0;
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                pack_c_size = 2 * pack_size;
            } else {
                pack_c_size = pack_size;
            }
            ret.ocpg = filter[flt_start] * pack_c_size;
            ret.icpg = filter[flt_start + 1] * pack_c_size;
        } else if (filter.ndim == img_ndim + 3) {
            // ohwi8o: spatial dims start right after oc
            flt_start = 0;
            flt_spatial_start = 1;
            ret.group = 1;
            ret.ocpg = filter[flt_start] * pack_size;
            ret.icpg = filter[flt_start + 3];

        } else {
            megdnn_assert(0, "not support nchwxx filter dim = %zu", filter.ndim);
        }
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        flt_start = 1;
        auto filter_oc = filter[flt_start];
        auto filter_ic = filter[flt_start + 1];
        if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4)) {
            // Depthwise case goihw8g
            megdnn_assert(
                    filter.ndim == img_ndim + 4,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    filter[filter.ndim - 1] == pack_size,
                    "last dim of filter must be %zu, but %zu", pack_size,
                    filter[filter.ndim - 1]);
            ret.group = filter[0] * pack_size;
            ret.ocpg = filter_oc;
            ret.icpg = filter_ic;

        } else {
            // norm group case goihw8i8o
            megdnn_assert(
                    filter.ndim == img_ndim + 5,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    (filter[filter.ndim - 1] == pack_size &&
                     filter[filter.ndim - 2] == pack_size) ||
                            (filter[filter.ndim - 1] == 2 * pack_size &&
                             filter[filter.ndim - 2] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);

            ret.group = filter[0];
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                ret.ocpg = filter_oc * 2 * pack_size;
                ret.icpg = filter_ic * 2 * pack_size;
            } else {
                ret.ocpg = filter_oc * pack_size;
                ret.icpg = filter_ic * pack_size;
            }
        }
    }
    ret.spatial_ndim = 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchwxx; "
            "got input dim = %zu",
            src_ndim);

    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] == 1,
                "NCHWXX has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

//! Fill @p ret from a filter layout in NCHW4/NCHW8/NCHW32/NCHW64 (and the
//! hybrid NCHW4_NCHW / NCHW4_NHWC / NCHW4_NCHW32 / NCHW32_NCHW4) formats.
//! Dilation must be 1 for these formats.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     * filter:
     *        OC, IC/pack_size, FH, FW, pack_size [dense]
     *        GROUP, OC, IC/pack_size, FH, FW, pack_size [group]
     */
    megdnn_assert(
            param.format == Param::Format::NCHW4 ||
            param.format == Param::Format::NCHW8 ||
            param.format == Param::Format::NCHW32 ||
            param.format == Param::Format::NCHW4_NCHW ||
            param.format == Param::Format::NCHW4_NHWC ||
            param.format == Param::Format::NCHW4_NCHW32 ||
            param.format == Param::Format::NCHW32_NCHW4 ||
            param.format == Param::Format::NCHW64);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 2;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchw4; "
            "got input dim = %zu",
            src_ndim);
    ret.ocpg = filter[flt_start];
    ret.icpg = filter[flt_start + 1] * pack_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] == 1,
                "NCHW4 has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

//! Fill @p ret from a filter layout in CHWN4 format. Dilation must be 1.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_chwnx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: IC / pack_size, H, W, N, pack_size
     * Filter:
     *        IC / pack_size, FH, FW, OC, pack_size [dense]
     *        GROUP, icpg / pack_size, FH, FW, ocpg, pack_size [group]
     *        not implemented [chanwise]
     */
    megdnn_assert(param.format == Param::Format::CHWN4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.icpg = filter[flt_start] * pack_size;
    ret.ocpg = filter[flt_start + 3];
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] == 1,
                "CHWNx has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

}  // namespace

namespace megdnn {
template <typename Parameter>
M
Megvii Engine Team 已提交
479 480
typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
        make_canonized_filter_meta(size_t src_ndim, const TensorLayout& filter) const {
481 482 483 484 485 486 487
    megdnn_assert_contiguous(filter);
    CanonizedFilterMeta ret;
    ret.dtype = filter.dtype;
    ret.format = param().format;
    if (param().mode == Mode::CONVOLUTION) {
        ret.should_flip = true;
    } else {
M
Megvii Engine Team 已提交
488
        megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode");
489 490 491 492 493 494 495 496 497 498 499 500
        ret.should_flip = false;
    }
    ret.stride[0] = param().stride_h;
    ret.stride[1] = param().stride_w;
    ret.padding[0] = param().pad_h;
    ret.padding[1] = param().pad_w;
    ret.dilation[0] = param().dilate_h;
    ret.dilation[1] = param().dilate_w;

    if (param().format == Param::Format::NHWCD4) {
        if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 ||
            filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
M
Megvii Engine Team 已提交
501 502
            make_canonized_filter_meta_nhwcd4_dot<Parameter>(
                    src_ndim, filter, param(), ret);
503
        } else {
M
Megvii Engine Team 已提交
504 505
            make_canonized_filter_meta_nhwcd4<Parameter>(
                    src_ndim, filter, param(), ret);
506
        }
M
Megvii Engine Team 已提交
507 508 509 510 511 512
    } else if (
            param().format == Param::Format::NCHW4 ||
            param().format == Param::Format::NCHW4_NCHW ||
            param().format == Param::Format::NCHW4_NHWC ||
            param().format == Param::Format::NCHW4_NCHW32) {
        make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret);
513
    } else if (param().format == Param::Format::NCHW8) {
M
Megvii Engine Team 已提交
514
        make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, param(), ret);
515
    } else if (param().format == Param::Format::NCHW88) {
M
Megvii Engine Team 已提交
516 517 518 519 520 521 522 523 524
        make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW44 ||
            param().format == Param::Format::NCHW44_DOT) {
        make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW32 ||
            param().format == Param::Format::NCHW32_NCHW4) {
        make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, param(), ret);
525
    } else if (param().format == Param::Format::CHWN4) {
M
Megvii Engine Team 已提交
526
        make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, param(), ret);
527
    } else if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
528
        make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, param(), ret);
529
    } else {
M
Megvii Engine Team 已提交
530 531 532 533
        megdnn_assert(
                param().format == Param::Format::NHWC ||
                param().format == Param::Format::NCHW);
        make_canonized_filter_meta_nchw_nhwc<Parameter>(src_ndim, filter, param(), ret);
534 535 536 537 538
    }
    return ret;
}

template <typename Parameter>
M
Megvii Engine Team 已提交
539 540
void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
        DType src, DType filter, DType& dst) const {
541 542 543 544 545 546 547
    // The first one will be the default choice.
    SmallVector<DType> supported_dst_dtype;
    // We rely on megdnn_assert(src.enumv() == filter.enumv()) here.
    if (src.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(src);
    } else if (src.enumv() == DTypeEnum::Int8) {
        supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
M
Megvii Engine Team 已提交
548 549 550 551
    } else if (
            src.enumv() == DTypeEnum::QuantizedS8 ||
            src.enumv() == DTypeEnum::Quantized8Asymm ||
            src.enumv() == DTypeEnum::QuantizedS4 ||
552 553
            src.enumv() == DTypeEnum::Quantized4Asymm ||
            src.enumv() == DTypeEnum::QuantizedS1) {
M
Megvii Engine Team 已提交
554 555 556 557 558 559 560 561
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
        bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
                                        ((dst.enumv() == DTypeEnum::QuantizedS4 ||
                                          dst.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         src.enumv() == DTypeEnum::QuantizedS8) ||
                                        ((src.enumv() == DTypeEnum::QuantizedS4 ||
                                          src.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         dst.enumv() == DTypeEnum::QuantizedS8));
562
        if (cond_dst) {
563 564
            supported_dst_dtype.push_back(dst);
        }
565 566 567
        if (src.enumv() == DTypeEnum::QuantizedS8) {
            supported_dst_dtype.push_back(dtype::Float32());
        }
568 569 570
    } else if (src.enumv() == DTypeEnum::QuantizedS32) {
        //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
        megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
M
Megvii Engine Team 已提交
571 572 573 574 575
        supported_dst_dtype.push_back(dtype::QuantizedS8(
                src.param<dtype::QuantizedS32>().scale /
                filter.param<dtype::QuantizedS8>().scale));
    } else {
        megdnn_throw(ssprintf(
576 577 578 579 580 581 582 583 584
                "runtime does not support input / filter DType: %s x %s"
                "now support case list: FLOAT x FLOAT\n"
                "                       Int8 x Int8\n"
                "                       QuantizedS8 x QuantizedS8\n"
                "                       Quantized8Asymm x Quantized8Asymm\n"
                "                       QuantizedS4 x QuantizedS4\n"
                "                       Quantized4Asymm x Quantized4Asymm\n"
                "                       QuantizedS1 x QuantizedS1\n",
                src.name(), filter.name()));
585 586 587 588
    }
    if (!dst.valid()) {
        dst = supported_dst_dtype.at(0);
    } else {
589 590 591 592 593 594 595 596
        bool dst_supported = false;
        for (auto&& dt : supported_dst_dtype) {
            if (dtype_almost_equal(dt, dst)) {
                dst_supported = true;
                break;
            }
        }
        MEGDNN_MARK_USED_VAR(dst_supported);
M
Megvii Engine Team 已提交
597
        megdnn_assert(
598 599 600 601 602 603 604 605 606 607 608 609 610 611 612
                dst_supported,
                "runtime does not support Conv(%s, %s) -> %s"
                "now support case list: Conv(FLOAT x FLOAT) -> FLOAT\n"
                "                       Conv(Int8 x Int8) -> Int32\n"
                "                       Conv(QuantizedS8 x QuantizedS8) -> "
                "QuantizedS32\n"
                "                       Conv(Quantized8Asymm x Quantized8Asymm) -> "
                "Quantized32Asymm\n"
                "                       Conv(QuantizedS4 x QuantizedS4) -> "
                "QuantizedS32\n"
                "                       Conv(Quantized4Asymm x Quantized4Asymm) -> "
                "Quantized32Asymm\n"
                "                       Conv(QuantizedS1 x QuantizedS1) -> "
                "QuantizedS32\n",
                src.name(), filter.name(), dst.name());
613
    }
M
Megvii Engine Team 已提交
614 615 616
    megdnn_assert(
            (param().compute_mode == Param::ComputeMode::FLOAT32 ||
             param().compute_mode == Param::ComputeMode::DEFAULT)
617
#if !MEGDNN_DISABLE_FLOAT16
M
Megvii Engine Team 已提交
618 619
                    || src.enumv() == DTypeEnum::Float16 ||
                    src.enumv() == DTypeEnum::BFloat16
620
#endif
M
Megvii Engine Team 已提交
621 622 623
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
624 625 626
}

template <typename Parameter>
M
Megvii Engine Team 已提交
627 628 629 630
typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
        deduce_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                TensorLayout& dst) const {
631 632 633
    auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
M
Megvii Engine Team 已提交
634 635 636 637 638
    megdnn_assert(
            ((src.dtype.enumv() == filter.dtype.enumv()) ||
             (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
              filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
            "%s", errmsg().c_str());
639 640 641
    check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
    size_t img_dim;
    if (param().format == Param::Format::NCHW ||
642
        param().format == Param::Format::NHWC) {
643
        img_dim = src.ndim - 2;
M
Megvii Engine Team 已提交
644 645 646
        megdnn_assert(
                filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, "%s",
                errmsg().c_str());
647 648

    } else {
M
Megvii Engine Team 已提交
649 650 651 652 653 654 655 656 657 658 659 660 661 662
        megdnn_assert(
                param().format == Param::Format::NHWCD4 ||
                param().format == Param::Format::NCHW4 ||
                param().format == Param::Format::NCHW4_NCHW ||
                param().format == Param::Format::NCHW4_NHWC ||
                param().format == Param::Format::NCHW4_NCHW32 ||
                param().format == Param::Format::NCHW44 ||
                param().format == Param::Format::NCHW44_DOT ||
                param().format == Param::Format::NCHW8 ||
                param().format == Param::Format::NCHW32 ||
                param().format == Param::Format::NCHW32_NCHW4 ||
                param().format == Param::Format::NCHW88 ||
                param().format == Param::Format::CHWN4 ||
                param().format == Param::Format::NCHW64);
663
        img_dim = src.ndim - 3;
664
        if ((param().format == Param::Format::NCHW88 ||
665
             param().format == Param::Format::NCHW44_DOT ||
666 667
             param().format == Param::Format::NCHW44) &&
            filter.ndim == 5) {
668 669
            img_dim = src.ndim - 2;
        }
M
Megvii Engine Team 已提交
670 671 672 673 674 675 676 677
        megdnn_assert(
                filter.ndim == img_dim + 3 ||
                        (filter.ndim == img_dim + 2 &&
                         (param().format == Param::Format::NCHW88 ||
                          param().format == Param::Format::NCHW44_DOT ||
                          param().format == Param::Format::NCHW44)) ||
                        filter.ndim == img_dim + 4 || filter.ndim == img_dim + 5,
                "%s", errmsg().c_str());
678 679 680
        if (param().format == Param::Format::NCHW4 ||
            param().format == Param::Format::NCHW4_NCHW ||
            param().format == Param::Format::NCHW4_NCHW32) {
M
Megvii Engine Team 已提交
681 682 683 684 685 686 687 688 689 690 691 692
            megdnn_assert(
                    src.ndim == 5 &&
                            (filter.ndim == 5 || filter.ndim == 6 ||
                             filter.ndim == 7) &&
                            src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
                    "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and "
                    "filter's ndim is "
                    "5 or 6, and "
                    "last shape "
                    "is 4 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
693 694 695 696
        }
        if (param().format == Param::Format::NCHW8) {
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
M
Megvii Engine Team 已提交
697
                            src[src.ndim - 1] == 8 && filter[filter.ndim - 1] == 8,
698 699 700 701 702
                    "NCHW8 require src and filter's ndim is 5 or 6, and last "
                    "shape is 8 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
        }
703 704
        if (param().format == Param::Format::NCHW32 ||
            param().format == Param::Format::NCHW32_NCHW4) {
M
Megvii Engine Team 已提交
705 706 707 708 709 710 711 712
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
                            src[src.ndim - 1] == 32 && filter[filter.ndim - 1] == 32,
                    "NCHW32/NCHW32_NCHW4 require src and filter's ndim "
                    "is 5 or 6, and last "
                    "shape is 32 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
713
        }
714
        if (param().format == Param::Format::NCHW88) {
M
Megvii Engine Team 已提交
715 716 717 718 719 720 721 722 723 724 725
            megdnn_assert(
                    (src.ndim == 4 && filter.ndim == 5 &&
                     filter[filter.ndim - 1] == 8) ||
                            (src.ndim == 5 &&
                             ((filter.ndim == 6 && filter[filter.ndim - 1] == 8) ||
                              (filter.ndim == 7 && filter[filter.ndim - 1] == 8 &&
                               filter[filter.ndim - 2] == 8)) &&
                             src[src.ndim - 1] == 8),
                    "NCHW88 require src ndim is 5 and filter's ndim is 6 "
                    ", and last shape two is 8 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
726
        }
727
        if (param().format == Param::Format::NCHW44 ||
728
            param().format == Param::Format::NCHW44_DOT) {
729 730
            //! support nchw44 filter change to 88 for int8 winogradf23_88 using
            //! MK8 mamtul
M
Megvii Engine Team 已提交
731 732 733 734 735 736 737 738 739 740 741 742 743 744 745
            megdnn_assert(
                    (src.ndim == 4 && filter.ndim == 5 &&
                     filter[filter.ndim - 1] == 4) ||
                            (src.ndim == 5 &&
                             ((filter.ndim == 6 && (filter[filter.ndim - 1] == 4 ||
                                                    filter[filter.ndim - 1] == 8)) ||
                              (filter.ndim == 7 &&
                               (filter[filter.ndim - 1] == 4 ||
                                filter[filter.ndim - 1] == 8) &&
                               (filter[filter.ndim - 2] == 4 ||
                                filter[filter.ndim - 2] == 8))) &&
                             src[src.ndim - 1] == 4),
                    "NCHW44 require src ndim is 5 and filter's ndim is 6 "
                    ", and last shape two is 4 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
746
        }
747 748 749
        if (param().format == Param::Format::CHWN4) {
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
M
Megvii Engine Team 已提交
750
                            src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
751 752 753 754 755
                    "CHWN4 require src and filter's ndim is 5 or 6, and last "
                    "shape is 4 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
        }
756
        if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
757 758 759 760 761 762
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
                            src[src.ndim - 1] == 64 && filter[filter.ndim - 1] == 64,
                    "NCHW64 require src and filter's ndim is 5 or 6, and "
                    "last shape is 64 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
763
        }
764
    }
M
Megvii Engine Team 已提交
765
    megdnn_assert(img_dim == 2, "currently only convolution on 2D image is supported");
766 767
    auto cflt = make_canonized_filter_meta(src.ndim, filter);
    if (param().format == Param::Format::NCHW ||
768
        param().format == Param::Format::NHWC) {
769 770
        size_t src_or_dst_c_pos = 0;
        size_t src_or_dst_spatial_start = 0;
771
        if (param().format == Param::Format::NCHW) {
772 773 774
            src_or_dst_c_pos = 1;
            src_or_dst_spatial_start = 2;
        } else {
M
Megvii Engine Team 已提交
775
            megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
776 777 778
            src_or_dst_c_pos = 3;
            src_or_dst_spatial_start = 1;
        }
M
Megvii Engine Team 已提交
779
        megdnn_assert(
780 781 782 783
                cflt.icpg * cflt.group == src[src_or_dst_c_pos],
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str());
784 785 786 787 788 789 790 791 792
        dst.ndim = src.ndim;
        dst[0] = src[0];
        dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
        for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
            dst[i + src_or_dst_spatial_start] = infer_conv_shape(
                    src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                    cflt.stride[i], cflt.padding[i]);
        }
    } else if (param().format == Param::Format::NCHW4) {
M
Megvii Engine Team 已提交
793 794 795 796
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
797 798 799 800
                cflt.icpg * cflt.group == src[1] * 4,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
801 802 803 804 805
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
806 807 808 809
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
810 811
        dst[4] = 4;
    } else if (param().format == Param::Format::NCHW8) {
M
Megvii Engine Team 已提交
812 813 814 815
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
816 817 818 819
                cflt.icpg * cflt.group == src[1] * 8,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[1] * 8, cflt.icpg * cflt.group, errmsg().c_str());
820 821 822 823 824
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 8 == 0);
        dst[1] = oc / 8;
M
Megvii Engine Team 已提交
825 826 827 828
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
829 830
        dst[4] = 8;
    } else if (param().format == Param::Format::NCHW32) {
M
Megvii Engine Team 已提交
831 832 833 834
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
835 836 837 838
                cflt.icpg * cflt.group == src[1] * 32,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
839 840 841 842 843
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 32 == 0);
        dst[1] = oc / 32;
M
Megvii Engine Team 已提交
844 845 846 847
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
848
        dst[4] = 32;
849
    } else if (param().format == Param::Format::NCHW88) {
850
        megdnn_assert(
851
                src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
852
                "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
853 854 855 856 857
        dst.ndim = 5;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 8 == 0);
        dst[1] = oc / 8;
M
Megvii Engine Team 已提交
858 859 860 861
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
862 863
        dst[4] = 8;
        if (cflt.group == 1) {
M
Megvii Engine Team 已提交
864 865 866
            megdnn_assert(
                    cflt.icpg * cflt.group == src[1] * 8 ||
                            (cflt.icpg * cflt.group == src[1]),
867 868 869 870 871
                    "group conv channel mismatch : input channel got %zu, and "
                    "filter channel got %u. More details about src, filter and dst : "
                    "\n%s",
                    src.ndim == 5 ? src[1] * 8 : src[1], cflt.icpg * cflt.group,
                    errmsg().c_str());
872 873
        }

874 875 876 877
    } else if (
            param().format == Param::Format::NCHW44 ||
            param().format == Param::Format::NCHW44_DOT) {
        megdnn_assert(
878
                src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
879
                "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
880 881 882 883 884
        dst.ndim = 5;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
885 886 887 888
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
889 890
        dst[4] = 4;
        if (cflt.group == 1) {
M
Megvii Engine Team 已提交
891 892 893
            megdnn_assert(
                    cflt.icpg * cflt.group == src[1] * 4 ||
                            (cflt.icpg * cflt.group == src[1]),
894 895 896 897 898
                    "group conv channel mismatch : input channel got %zu, and "
                    "filter channel got %u. More details about src, filter and dst : "
                    "\n%s",
                    src.ndim == 5 ? src[1] * 4 : src[1], cflt.icpg * cflt.group,
                    errmsg().c_str());
899
        }
900
    } else if (param().format == Param::Format::CHWN4) {
M
Megvii Engine Team 已提交
901 902 903 904
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
905 906 907 908
                cflt.icpg * cflt.group == src[0] * 4,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[0] * 4, cflt.icpg * cflt.group, errmsg().c_str());
909 910 911 912 913
        dst.ndim = src.ndim;
        dst[3] = src[3];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[0] = oc / 4;
M
Megvii Engine Team 已提交
914 915 916 917
        dst[1] = infer_conv_shape(
                src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
918
        dst[4] = 4;
919
    } else if (param().format == Param::Format::NCHW4_NCHW) {
M
Megvii Engine Team 已提交
920 921 922 923
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
924 925 926 927
                cflt.icpg * cflt.group == src[1] * 4,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
928 929 930 931
        dst.ndim = 4;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        dst[1] = oc;
M
Megvii Engine Team 已提交
932 933 934 935
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
936
    } else if (param().format == Param::Format::NCHW4_NHWC) {
M
Megvii Engine Team 已提交
937 938 939 940
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
941 942 943 944
                cflt.icpg * cflt.group == src[1] * 4,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
945 946
        dst.ndim = 4;
        dst[0] = src[0];
M
Megvii Engine Team 已提交
947 948 949 950
        dst[1] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[2] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
951 952
        auto oc = cflt.ocpg * cflt.group;
        dst[3] = oc;
953
    } else if (param().format == Param::Format::NCHW4_NCHW32) {
M
Megvii Engine Team 已提交
954 955 956 957
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
958 959 960 961
                cflt.icpg * cflt.group == src[1] * 4,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
962 963 964 965 966
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 32 == 0);
        dst[1] = oc / 32;
M
Megvii Engine Team 已提交
967 968 969 970
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
971 972
        dst[4] = 32;
    } else if (param().format == Param::Format::NCHW32_NCHW4) {
M
Megvii Engine Team 已提交
973 974 975 976
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
977 978 979 980
                cflt.icpg * cflt.group == src[1] * 32,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
981 982 983 984 985
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
986 987 988 989
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
990
        dst[4] = 4;
991
    } else if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
992 993 994 995
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
996 997 998 999
                cflt.icpg * cflt.group == src[1] * 64,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[1] * 64, cflt.icpg * cflt.group, errmsg().c_str());
1000 1001 1002 1003 1004
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 64 == 0);
        dst[1] = oc / 64;
M
Megvii Engine Team 已提交
1005 1006 1007 1008
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
1009
        dst[4] = 64;
1010 1011
    } else {
        megdnn_assert(param().format == Param::Format::NHWCD4);
M
Megvii Engine Team 已提交
1012 1013 1014 1015
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
1016 1017 1018 1019
                cflt.icpg * cflt.group == src[2] * 4,
                "group conv channel mismatch : input channel got %zu, and "
                "filter channel got %u. More details for src, filter and dst : \n%s",
                src[2] * 4, cflt.icpg * cflt.group, errmsg().c_str());
1020 1021 1022 1023 1024
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[2] = oc / 4;
M
Megvii Engine Team 已提交
1025 1026 1027 1028
        dst[1] = infer_conv_shape(
                src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
1029 1030 1031
        megdnn_assert(src[4] == 4);
        dst[4] = 4;
    }
M
Megvii Engine Team 已提交
1032
    if (!src.format.is_default() && !src.format.is_lowbit_aligned()) {  // propagate
M
Megvii Engine Team 已提交
1033 1034 1035
        dst.format = src.format;
    } else {  // determined by dtype
        dst.format = TensorFormat(dst.dtype);
M
Megvii Engine Team 已提交
1036
    }
1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050
    dst.init_contiguous_stride();
    return cflt;
}

/**
 * \warning: An explicit specialization shall be declared in a namespace
 * enclosing the specialized template. An explicit specialization whose
 * declarator-id is not qualified shall be declared in the nearest enclosing
 * namespace of the template, or, if the namespace is inline (7.3.1), any
 * namespace from its enclosing namespace set.
 * refer to:
 * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace
 */
template <>
M
Megvii Engine Team 已提交
1051 1052 1053 1054 1055
ConvolutionBase<param::Convolution>::CanonizedFilterMeta ConvolutionBase<
        param::Convolution>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1056 1057
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1058 1059 1060 1061 1062 1063 1064 1065 1066
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

template <>
M
Megvii Engine Team 已提交
1067 1068 1069 1070
ConvolutionBase<param::ConvBias>::CanonizedFilterMeta ConvolutionBase<param::ConvBias>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1071 1072
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1073 1074 1075 1076 1077 1078 1079 1080 1081
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

template <>
M
Megvii Engine Team 已提交
1082 1083 1084 1085 1086
ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta ConvolutionBase<
        param::BatchConvBias>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1087 1088
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

//! Deduce (or validate, if \p dst is already set) the output dtype of a
//! forward convolution; the supported src/filter/dst dtype combinations are
//! enforced by check_or_deduce_dtype_fwd.
void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) {
    check_or_deduce_dtype_fwd(src, filter, dst);
}

M
Megvii Engine Team 已提交
1101 1102
void ConvolutionForward::deduce_layout(
        const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
1103 1104 1105 1106
    deduce_layout_fwd(src, filter, dst);
}

//! Pre-execution validation for forward convolution: checks the layout
//! triple and that the caller-provided workspace is large enough for the
//! (possibly preprocessed-filter-dependent) workspace requirement.
ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) {
    auto ret = check_layout_fwd(src, filter, dst);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(src, filter, dst, preprocessed_filter);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

M
Megvii Engine Team 已提交
1116 1117 1118
ConvolutionBackwardData::CanonizedFilterMeta ConvolutionBackwardData::check_exec(
        const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
1119 1120 1121 1122 1123 1124 1125 1126 1127
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;

    std::swap(grad_fwd.dtype, diff_fwd.dtype);

    grad_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
M
Megvii Engine Team 已提交
1128
    auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad);
1129 1130 1131 1132
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

M
Megvii Engine Team 已提交
1133
void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad) {
1134 1135 1136 1137 1138 1139
    SmallVector<DType> supported_dst_dtype;
    if (filter.category() == diff.category() &&
        filter.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(filter);
    } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
        supported_dst_dtype.push_back(dtype::Int32());
M
Megvii Engine Team 已提交
1140 1141 1142 1143 1144 1145
    } else if (
            (filter.enumv() == DTypeEnum::QuantizedS8 &&
             diff.enumv() == DTypeEnum::QuantizedS8) ||
            (filter.enumv() == DTypeEnum::Quantized8Asymm &&
             diff.enumv() == DTypeEnum::Quantized8Asymm)) {
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
1146 1147 1148 1149
        if (grad.valid() && grad.enumv() == diff.enumv()) {
            supported_dst_dtype.push_back(grad);
        }
    } else {
M
Megvii Engine Team 已提交
1150
        megdnn_throw(ssprintf(
1151 1152 1153 1154 1155 1156
                "runtime does not support input / diff DType: %s x %s"
                "now support case list: FLOAT x FLOAT\n"
                "                       Int8 x Int8\n"
                "                       QuantizedS8 x QuantizedS8\n"
                "                       Quantized8Asymm x Quantized8Asymm\n",
                filter.name(), diff.name()));
1157 1158 1159 1160
    }
    if (!grad.valid()) {
        grad = supported_dst_dtype.at(0);
    } else {
M
Megvii Engine Team 已提交
1161 1162
        megdnn_assert(
                vec_contains(supported_dst_dtype, grad),
1163 1164 1165 1166 1167 1168 1169 1170
                "runtime does not support ConvBwd(%s, %s) -> %s"
                "now support case list: ConvBwd(FLOAT x FLOAT) -> FLOAT\n"
                "                       ConvBwd(Int8 x Int8) -> Int32\n"
                "                       ConvBwd(QuantizedS8 x QuantizedS8) -> "
                "QuantizedS32\n"
                "                       ConvBwd(Quantized8Asymm x Quantized8Asymm) -> "
                "Quantized32Asymm\n",
                filter.name(), diff.name(), grad.name());
1171
    }
M
Megvii Engine Team 已提交
1172 1173
    megdnn_assert(
            param().compute_mode != Param::ComputeMode::FLOAT32
1174
#if !MEGDNN_DISABLE_FLOAT16
M
Megvii Engine Team 已提交
1175 1176
                    || filter.enumv() == DTypeEnum::Float16 ||
                    filter.enumv() == DTypeEnum::BFloat16
1177
#endif
M
Megvii Engine Team 已提交
1178 1179 1180
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
1181 1182
}

M
Megvii Engine Team 已提交
1183 1184
void ConvolutionBackwardData::deduce_layout(
        const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
1185 1186 1187 1188
    auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert_contiguous(filter);
    megdnn_assert_contiguous(diff);
M
Megvii Engine Team 已提交
1189
    megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
1190 1191 1192 1193 1194 1195
    megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());

    deduce_dtype(filter.dtype, diff.dtype, grad.dtype);

    auto cflt = make_canonized_filter_meta(diff.ndim, filter);

M
Megvii Engine Team 已提交
1196
    auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210
        MEGDNN_MARK_USED_VAR(errmsg);
        auto i = (out - 1) * stride + filter;
        megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
        return i - pad * 2;
    };

    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NHWC) {
        size_t src_or_dst_c_pos = 0;
        size_t src_or_dst_spatial_start = 0;
        if (param().format == Param::Format::NCHW) {
            src_or_dst_c_pos = 1;
            src_or_dst_spatial_start = 2;
        } else {
M
Megvii Engine Team 已提交
1211
            megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
1212 1213 1214
            src_or_dst_c_pos = 3;
            src_or_dst_spatial_start = 1;
        }
M
Megvii Engine Team 已提交
1215 1216 1217
        megdnn_assert(
                cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s",
                errmsg().c_str());
1218 1219 1220 1221
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
        for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
1222 1223 1224
            grad[i + src_or_dst_spatial_start] =
                    deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                           cflt.stride[i], cflt.padding[i]);
1225
        }
1226
    } else if (param().format == Param::Format::NCHW4) {
M
Megvii Engine Team 已提交
1227 1228 1229
        megdnn_assert(
                diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
                diff.ndim);
1230
        megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
M
Megvii Engine Team 已提交
1231
        megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
1232 1233 1234 1235 1236
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        auto ic = cflt.icpg * cflt.group;
        megdnn_assert(ic % 4 == 0);
        grad[1] = ic / 4;
M
Megvii Engine Team 已提交
1237 1238 1239 1240
        grad[2] = deduce(
                diff[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        grad[3] = deduce(
                diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
1241 1242
        megdnn_assert(diff[4] == 4);
        grad[4] = 4;
1243 1244
    } else {
        megdnn_assert(param().format == Param::Format::NHWCD4);
M
Megvii Engine Team 已提交
1245 1246 1247 1248
        megdnn_assert(
                diff.ndim == 5, "valid diff ndim for NHWCD4, expected=5, got=%zu",
                diff.ndim);
        megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", errmsg().c_str());
1249 1250 1251 1252 1253
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        auto ic = cflt.icpg * cflt.group;
        megdnn_assert(ic % 4 == 0);
        grad[2] = ic / 4;
M
Megvii Engine Team 已提交
1254 1255 1256 1257
        grad[1] = deduce(
                diff[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        grad[3] = deduce(
                diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
1258 1259 1260 1261 1262 1263 1264
        megdnn_assert(diff[4] == 4);
        grad[4] = 4;
    }
    grad.format = diff.format;
    grad.init_contiguous_stride();
}

ConvolutionBackwardFilter::CanonizedFilterMeta ConvolutionBackwardFilter::check_exec(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    // Validate dtypes and layouts for conv backward-filter, and ensure the
    // caller-supplied workspace is large enough. Returns the canonized filter
    // meta produced by the forward-direction layout check.
    megdnn_assert(
            src.dtype.category() == DTypeCategory::FLOAT &&
                    diff.dtype.category() == DTypeCategory::FLOAT &&
                    grad.dtype.category() == DTypeCategory::FLOAT,
            "only float type is supported for conv backward filter");
    // Run the forward-direction layout check on contiguous copies of src/diff.
    TensorLayout contig_src = src;
    TensorLayout contig_diff = diff;
    contig_src.init_contiguous_stride();
    contig_diff.init_contiguous_stride();
    auto meta = check_layout_fwd(contig_src, grad, contig_diff);
    auto required_workspace = get_workspace_in_bytes(src, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace);
    return meta;
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen