convolution.cpp 51.3 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/common/convolution.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
11 12 13 14 15 16 17 18 19
 */

#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

using namespace megdnn;

namespace {
template <typename Param>
M
Megvii Engine Team 已提交
20 21 22
std::string get_errmsg(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const Param& param) {
23 24 25 26
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(dst);
    return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
M
Megvii Engine Team 已提交
27
           megdnn_layout_msg(dst) + ", " + "is_nchw=" +
M
Megvii Engine Team 已提交
28 29 30 31
           std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
           "is_xcorr=" +
           std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " +
           "pad_h=" + std::to_string(param.pad_h) + ", " +
M
Megvii Engine Team 已提交
32 33 34 35 36
           "pad_w=" + std::to_string(param.pad_w) + ", " +
           "stride_h=" + std::to_string(param.stride_h) + ", " +
           "stride_w=" + std::to_string(param.stride_w) + ", " +
           "dilate_h=" + std::to_string(param.dilate_h) + ", " +
           "dilate_w=" + std::to_string(param.dilate_w);
37 38 39 40 41 42 43 44 45 46 47
}

/**
 * Translate a raw filter extent into the spatial extent stored in the
 * canonized filter meta.
 *
 * This generic version returns \p filter unchanged; neither the format
 * template argument nor the parameter object is consulted.
 */
template <typename Param, typename Param::Format /* format */>
uint32_t spatial_getter(uint32_t filter, const Param& /* param */) {
    return filter;
}

template <typename Parameter, typename Param>
void make_canonized_filter_meta_nchw_nhwc(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
M
Megvii Engine Team 已提交
48 49
    megdnn_assert(
            param.format == Param::Format::NCHW || param.format == Param::Format::NHWC);
50 51 52 53 54 55 56 57 58 59 60
    auto img_ndim = src_ndim - 2;
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = 1;
        flt_start = 0;
    } else {
M
Megvii Engine Team 已提交
61 62 63
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);

        // grp, oc, ic, dims[]
        ret.group = filter[0];
        flt_start = 1;
    }

    uint32_t ic_block_size = 1, oc_block_size = 1;
    if (param.format == Param::Format::NCHW) {
        // filter should be (oc, ic, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
    } else {
M
Megvii Engine Team 已提交
82 83
        megdnn_assert(
                param.format == Param::Format::NHWC, "invalid conv tensor format");
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
        // filter should be (oc, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 3;
    }
    ret.spatial_ndim = src_ndim - 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
    ret.icpg = filter[flt_start + icpg_pos] * ic_block_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
99 100 101
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
102 103
        ret.spatial[i] = spatial_getter<Param, Param::Format::NCHW>(
                filter[i + flt_start + flt_spatial_start], param);
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

/**
 * Fill the group / channel / spatial fields of \p ret for NHWCD4 filters.
 *
 * \param src_ndim number of axes of the input tensor; must be 5
 * \param filter   filter layout to inspect
 * \param param    convolution parameters; format and sparse are read
 * \param ret      meta to update; ret.dilation is read and must already be set
 */
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N H IC/4 W 4
     * Filter:
     *        OC/4, FH, FW, IC, 4 [dense]
     *        GROUP, OC/4, FH, FW, IC, 4 [group]
     *        GROUP/4, 1, FH, FW, 4 [chanwise]
     */
    megdnn_assert(param.format == Param::Format::NHWCD4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc/4, dims[], ic, 4
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // NOTE(review): a filter with img_ndim + 3 axes whose filter[1] != 1
        // is handled as an ordinary group filter here, while the _dot variant
        // asserts filter[1] == 1 -- confirm this difference is intended.
        if (filter.ndim == img_ndim + 3 && filter[1] == 1) {
            is_chanwise = true;
            // The leading axis stores GROUP/4.
            ret.group = filter[0] * 4;
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    // spatial_ndim == 2 is equivalent to src_ndim == 5 (N H IC/4 W 4).
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        // The filter stores OC/4 and the full IC.
        ret.ocpg = filter[flt_start] * 4;
        ret.icpg = filter[flt_start + 3];
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        // Extent covered by the kernel once dilation gaps are included.
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

/**
 * Fill the group / channel / spatial fields of \p ret for NHWCD4 filters in
 * the "dot" layout, where the input channels are additionally blocked by 4.
 *
 * \param src_ndim number of axes of the input tensor; must be 5
 * \param filter   filter layout to inspect
 * \param param    convolution parameters; format and sparse are read
 * \param ret      meta to update; ret.dilation is read and must already be set
 */
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4_dot(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N H IC/4 W 4
     * Filter:
     *        GROUP/4, 1, FH, FW, 4 [chanwise]
     *        OC/4, FH, FW, IC/4, 4, 4 [dense]
     *        GROUP, OC/4, FH, FW, IC/4, 4, 4 [group]
     */
    megdnn_assert(param.format == Param::Format::NHWCD4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc/4, dims[], ic/4, 4, 4
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        if (filter.ndim == img_ndim + 3) {
            // Only the channel-wise layout has img_ndim + 3 axes.
            megdnn_assert(filter[1] == 1);
            is_chanwise = true;
            // The leading axis stores GROUP/4.
            ret.group = filter[0] * 4;
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    // spatial_ndim == 2 is equivalent to src_ndim == 5 (N H IC/4 W 4).
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        // The filter stores OC/4 and IC/4.
        ret.ocpg = filter[flt_start] * 4;
        ret.icpg = filter[flt_start + 3] * 4;
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        // Extent covered by the kernel once dilation gaps are included.
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwxx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     *
244 245 246 247 248 249 250
     ** NCHW44-DOT mode
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
     * [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *                  FH, FW, pack_size(OC), pack_size(IC)} [group]
     *
251
     * NCHW88 and NCHW44 mode
252 253 254 255 256 257 258 259 260 261
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
     * [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *                  FH, FW, pack_size(IC), pack_size(OC)} [group]
     *        {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
     *
     *
     */

M
Megvii Engine Team 已提交
262 263 264 265
    megdnn_assert(
            param.format == Param::Format::NCHW88 ||
            param.format == Param::Format::NCHW44 ||
            param.format == Param::Format::NCHW44_DOT);
266 267 268
    size_t img_ndim = 2;
    size_t flt_start = 0;
    size_t flt_spatial_start = 2;
269
    size_t pack_c_size = 0;
270 271 272
    if (param.sparse == Param::Sparse::DENSE) {
        if (filter.ndim == img_ndim + 4) {
            // oihw8i8o case
M
Megvii Engine Team 已提交
273 274 275 276 277 278 279
            megdnn_assert(
                    (filter[filter.ndim - 2] == pack_size &&
                     filter[filter.ndim - 1] == pack_size) ||
                            (filter[filter.ndim - 2] == 2 * pack_size &&
                             filter[filter.ndim - 1] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
280 281
            ret.group = 1;
            flt_start = 0;
282 283 284 285 286 287 288 289
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                pack_c_size = 2 * pack_size;
            } else {
                pack_c_size = pack_size;
            }
            ret.ocpg = filter[flt_start] * pack_c_size;
            ret.icpg = filter[flt_start + 1] * pack_c_size;
290 291 292 293 294 295 296 297 298
        } else if (filter.ndim == img_ndim + 3) {
            // ohwi8o
            flt_start = 0;
            flt_spatial_start = 1;
            ret.group = 1;
            ret.ocpg = filter[flt_start] * pack_size;
            ret.icpg = filter[flt_start + 3];

        } else {
M
Megvii Engine Team 已提交
299
            megdnn_assert(0, "not support nchwxx filter dim = %zu", filter.ndim);
300 301
        }
    } else {
M
Megvii Engine Team 已提交
302 303 304
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
305 306 307
        flt_start = 1;
        auto filter_oc = filter[flt_start];
        auto filter_ic = filter[flt_start + 1];
308
        if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4)) {
309
            // Depthwise case goihw8g
M
Megvii Engine Team 已提交
310 311 312 313 314 315 316 317 318
            megdnn_assert(
                    filter.ndim == img_ndim + 4,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    filter[filter.ndim - 1] == pack_size,
                    "last dim of filter must be %zu, but %zu", pack_size,
                    filter[filter.ndim - 1]);
319
            ret.group = filter[0] * pack_size;
320 321 322 323 324
            ret.ocpg = filter_oc;
            ret.icpg = filter_ic;

        } else {
            // norm group case goihw8i8o
M
Megvii Engine Team 已提交
325 326 327 328 329 330 331 332 333 334 335 336
            megdnn_assert(
                    filter.ndim == img_ndim + 5,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    (filter[filter.ndim - 1] == pack_size &&
                     filter[filter.ndim - 2] == pack_size) ||
                            (filter[filter.ndim - 1] == 2 * pack_size &&
                             filter[filter.ndim - 2] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
337 338

            ret.group = filter[0];
339 340 341 342 343 344 345 346
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                ret.ocpg = filter_oc * 2 * pack_size;
                ret.icpg = filter_ic * 2 * pack_size;
            } else {
                ret.ocpg = filter_oc * pack_size;
                ret.icpg = filter_ic * pack_size;
            }
347 348 349
        }
    }
    ret.spatial_ndim = 2;
M
Megvii Engine Team 已提交
350 351 352 353 354 355
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchwxx; "
            "got input dim = %zu",
            src_ndim);
356 357 358

    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
359 360 361 362 363
        megdnn_assert(
                dilation[i] == 1,
                "NCHWXX has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
364
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
365 366 367 368 369 370 371 372 373 374 375 376 377 378
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     * filter:
     *        OC, IC/pack_size, FH, FW, pack_size [dense]
     *        GROUP, OC, IC/pack_size, FH, FW, pack_size [group]
     */
M
Megvii Engine Team 已提交
379 380 381 382 383 384 385 386 387
    megdnn_assert(
            param.format == Param::Format::NCHW4 ||
            param.format == Param::Format::NCHW8 ||
            param.format == Param::Format::NCHW32 ||
            param.format == Param::Format::NCHW4_NCHW ||
            param.format == Param::Format::NCHW4_NHWC ||
            param.format == Param::Format::NCHW4_NCHW32 ||
            param.format == Param::Format::NCHW32_NCHW4 ||
            param.format == Param::Format::NCHW64);
388 389 390
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 2;
    if (param.sparse == Param::Sparse::DENSE) {
M
Megvii Engine Team 已提交
391 392 393 394 395
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
396 397 398 399
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
M
Megvii Engine Team 已提交
400 401 402 403 404 405 406 407
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
408 409 410 411
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
M
Megvii Engine Team 已提交
412 413 414 415 416 417
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchw4; "
            "got input dim = %zu",
            src_ndim);
418 419 420 421
    ret.ocpg = filter[flt_start];
    ret.icpg = filter[flt_start + 1] * pack_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
422 423 424 425 426
        megdnn_assert(
                dilation[i] == 1,
                "NCHW4 has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_chwnx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: IC / pack_size, H, W, N, pack_size
     * Filter:
     *        IC / pack_size, FH, FW, OC, pack_size [dense]
     *        GROUP, icpg / pack_size, FH, FW, ocpg, pack_size [group]
     *        not implemented [chanwise]
     */
    megdnn_assert(param.format == Param::Format::CHWN4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    if (param.sparse == Param::Sparse::DENSE) {
M
Megvii Engine Team 已提交
447 448 449 450 451
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
452 453 454 455
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
M
Megvii Engine Team 已提交
456 457 458 459 460 461 462 463
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
464 465 466 467 468 469 470 471 472 473 474 475 476
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.icpg = filter[flt_start] * pack_size;
    ret.ocpg = filter[flt_start + 3];
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
477 478 479 480 481
        megdnn_assert(
                dilation[i] == 1,
                "CHWNx has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
482 483 484 485 486 487 488 489 490
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

}  // namespace

namespace megdnn {
template <typename Parameter>
M
Megvii Engine Team 已提交
491 492
typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
        make_canonized_filter_meta(size_t src_ndim, const TensorLayout& filter) const {
493 494 495 496 497 498 499
    megdnn_assert_contiguous(filter);
    CanonizedFilterMeta ret;
    ret.dtype = filter.dtype;
    ret.format = param().format;
    if (param().mode == Mode::CONVOLUTION) {
        ret.should_flip = true;
    } else {
M
Megvii Engine Team 已提交
500
        megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode");
501 502 503 504 505 506 507 508 509 510 511 512
        ret.should_flip = false;
    }
    ret.stride[0] = param().stride_h;
    ret.stride[1] = param().stride_w;
    ret.padding[0] = param().pad_h;
    ret.padding[1] = param().pad_w;
    ret.dilation[0] = param().dilate_h;
    ret.dilation[1] = param().dilate_w;

    if (param().format == Param::Format::NHWCD4) {
        if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 ||
            filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
M
Megvii Engine Team 已提交
513 514
            make_canonized_filter_meta_nhwcd4_dot<Parameter>(
                    src_ndim, filter, param(), ret);
515
        } else {
M
Megvii Engine Team 已提交
516 517
            make_canonized_filter_meta_nhwcd4<Parameter>(
                    src_ndim, filter, param(), ret);
518
        }
M
Megvii Engine Team 已提交
519 520 521 522 523 524
    } else if (
            param().format == Param::Format::NCHW4 ||
            param().format == Param::Format::NCHW4_NCHW ||
            param().format == Param::Format::NCHW4_NHWC ||
            param().format == Param::Format::NCHW4_NCHW32) {
        make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret);
525
    } else if (param().format == Param::Format::NCHW8) {
M
Megvii Engine Team 已提交
526
        make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, param(), ret);
527
    } else if (param().format == Param::Format::NCHW88) {
M
Megvii Engine Team 已提交
528 529 530 531 532 533 534 535 536
        make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW44 ||
            param().format == Param::Format::NCHW44_DOT) {
        make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW32 ||
            param().format == Param::Format::NCHW32_NCHW4) {
        make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, param(), ret);
537
    } else if (param().format == Param::Format::CHWN4) {
M
Megvii Engine Team 已提交
538
        make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, param(), ret);
539
    } else if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
540
        make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, param(), ret);
541
    } else {
M
Megvii Engine Team 已提交
542 543 544 545
        megdnn_assert(
                param().format == Param::Format::NHWC ||
                param().format == Param::Format::NCHW);
        make_canonized_filter_meta_nchw_nhwc<Parameter>(src_ndim, filter, param(), ret);
546 547 548 549 550
    }
    return ret;
}

template <typename Parameter>
M
Megvii Engine Team 已提交
551 552
void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
        DType src, DType filter, DType& dst) const {
553 554 555 556 557 558 559
    // The first one will be the default choice.
    SmallVector<DType> supported_dst_dtype;
    // We rely on megdnn_assert(src.enumv() == filter.enumv()) here.
    if (src.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(src);
    } else if (src.enumv() == DTypeEnum::Int8) {
        supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
M
Megvii Engine Team 已提交
560 561 562 563 564 565 566 567 568 569 570 571 572
    } else if (
            src.enumv() == DTypeEnum::QuantizedS8 ||
            src.enumv() == DTypeEnum::Quantized8Asymm ||
            src.enumv() == DTypeEnum::QuantizedS4 ||
            src.enumv() == DTypeEnum::Quantized4Asymm) {
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
        bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
                                        ((dst.enumv() == DTypeEnum::QuantizedS4 ||
                                          dst.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         src.enumv() == DTypeEnum::QuantizedS8) ||
                                        ((src.enumv() == DTypeEnum::QuantizedS4 ||
                                          src.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         dst.enumv() == DTypeEnum::QuantizedS8));
573
        if (cond_dst) {
574 575
            supported_dst_dtype.push_back(dst);
        }
576 577 578
        if (src.enumv() == DTypeEnum::QuantizedS8) {
            supported_dst_dtype.push_back(dtype::Float32());
        }
579 580 581
    } else if (src.enumv() == DTypeEnum::QuantizedS32) {
        //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
        megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
M
Megvii Engine Team 已提交
582 583 584 585 586 587 588
        supported_dst_dtype.push_back(dtype::QuantizedS8(
                src.param<dtype::QuantizedS32>().scale /
                filter.param<dtype::QuantizedS8>().scale));
    } else {
        megdnn_throw(ssprintf(
                "unsupported input / filter DType: %s x %s", src.name(),
                filter.name()));
589 590 591 592
    }
    if (!dst.valid()) {
        dst = supported_dst_dtype.at(0);
    } else {
593 594 595 596 597 598 599 600
        bool dst_supported = false;
        for (auto&& dt : supported_dst_dtype) {
            if (dtype_almost_equal(dt, dst)) {
                dst_supported = true;
                break;
            }
        }
        MEGDNN_MARK_USED_VAR(dst_supported);
M
Megvii Engine Team 已提交
601 602 603
        megdnn_assert(
                dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(),
                filter.name(), dst.name());
604
    }
M
Megvii Engine Team 已提交
605 606 607
    megdnn_assert(
            (param().compute_mode == Param::ComputeMode::FLOAT32 ||
             param().compute_mode == Param::ComputeMode::DEFAULT)
608
#if !MEGDNN_DISABLE_FLOAT16
M
Megvii Engine Team 已提交
609 610
                    || src.enumv() == DTypeEnum::Float16 ||
                    src.enumv() == DTypeEnum::BFloat16
611
#endif
M
Megvii Engine Team 已提交
612 613 614
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
615 616 617
}

template <typename Parameter>
M
Megvii Engine Team 已提交
618 619 620 621
typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
        deduce_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                TensorLayout& dst) const {
622 623 624
    auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
M
Megvii Engine Team 已提交
625 626 627 628 629
    megdnn_assert(
            ((src.dtype.enumv() == filter.dtype.enumv()) ||
             (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
              filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
            "%s", errmsg().c_str());
630 631 632
    check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
    size_t img_dim;
    if (param().format == Param::Format::NCHW ||
633
        param().format == Param::Format::NHWC) {
634
        img_dim = src.ndim - 2;
M
Megvii Engine Team 已提交
635 636 637
        megdnn_assert(
                filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, "%s",
                errmsg().c_str());
638 639

    } else {
M
Megvii Engine Team 已提交
640 641 642 643 644 645 646 647 648 649 650 651 652 653
        megdnn_assert(
                param().format == Param::Format::NHWCD4 ||
                param().format == Param::Format::NCHW4 ||
                param().format == Param::Format::NCHW4_NCHW ||
                param().format == Param::Format::NCHW4_NHWC ||
                param().format == Param::Format::NCHW4_NCHW32 ||
                param().format == Param::Format::NCHW44 ||
                param().format == Param::Format::NCHW44_DOT ||
                param().format == Param::Format::NCHW8 ||
                param().format == Param::Format::NCHW32 ||
                param().format == Param::Format::NCHW32_NCHW4 ||
                param().format == Param::Format::NCHW88 ||
                param().format == Param::Format::CHWN4 ||
                param().format == Param::Format::NCHW64);
654
        img_dim = src.ndim - 3;
655
        if ((param().format == Param::Format::NCHW88 ||
656
             param().format == Param::Format::NCHW44_DOT ||
657 658
             param().format == Param::Format::NCHW44) &&
            filter.ndim == 5) {
659 660
            img_dim = src.ndim - 2;
        }
M
Megvii Engine Team 已提交
661 662 663 664 665 666 667 668
        megdnn_assert(
                filter.ndim == img_dim + 3 ||
                        (filter.ndim == img_dim + 2 &&
                         (param().format == Param::Format::NCHW88 ||
                          param().format == Param::Format::NCHW44_DOT ||
                          param().format == Param::Format::NCHW44)) ||
                        filter.ndim == img_dim + 4 || filter.ndim == img_dim + 5,
                "%s", errmsg().c_str());
669 670 671
        if (param().format == Param::Format::NCHW4 ||
            param().format == Param::Format::NCHW4_NCHW ||
            param().format == Param::Format::NCHW4_NCHW32) {
M
Megvii Engine Team 已提交
672 673 674 675 676 677 678 679 680 681 682 683
            megdnn_assert(
                    src.ndim == 5 &&
                            (filter.ndim == 5 || filter.ndim == 6 ||
                             filter.ndim == 7) &&
                            src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
                    "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and "
                    "filter's ndim is "
                    "5 or 6, and "
                    "last shape "
                    "is 4 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
684 685 686 687
        }
        if (param().format == Param::Format::NCHW8) {
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
M
Megvii Engine Team 已提交
688
                            src[src.ndim - 1] == 8 && filter[filter.ndim - 1] == 8,
689 690 691 692 693
                    "NCHW8 require src and filter's ndim is 5 or 6, and last "
                    "shape is 8 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
        }
694 695
        if (param().format == Param::Format::NCHW32 ||
            param().format == Param::Format::NCHW32_NCHW4) {
M
Megvii Engine Team 已提交
696 697 698 699 700 701 702 703
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
                            src[src.ndim - 1] == 32 && filter[filter.ndim - 1] == 32,
                    "NCHW32/NCHW32_NCHW4 require src and filter's ndim "
                    "is 5 or 6, and last "
                    "shape is 32 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
704
        }
705
        if (param().format == Param::Format::NCHW88) {
M
Megvii Engine Team 已提交
706 707 708 709 710 711 712 713 714 715 716
            megdnn_assert(
                    (src.ndim == 4 && filter.ndim == 5 &&
                     filter[filter.ndim - 1] == 8) ||
                            (src.ndim == 5 &&
                             ((filter.ndim == 6 && filter[filter.ndim - 1] == 8) ||
                              (filter.ndim == 7 && filter[filter.ndim - 1] == 8 &&
                               filter[filter.ndim - 2] == 8)) &&
                             src[src.ndim - 1] == 8),
                    "NCHW88 require src ndim is 5 and filter's ndim is 6 "
                    ", and last shape two is 8 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
717
        }
718
        if (param().format == Param::Format::NCHW44 ||
719
            param().format == Param::Format::NCHW44_DOT) {
720 721
            //! support nchw44 filter change to 88 for int8 winogradf23_88 using
            //! MK8 mamtul
M
Megvii Engine Team 已提交
722 723 724 725 726 727 728 729 730 731 732 733 734 735 736
            megdnn_assert(
                    (src.ndim == 4 && filter.ndim == 5 &&
                     filter[filter.ndim - 1] == 4) ||
                            (src.ndim == 5 &&
                             ((filter.ndim == 6 && (filter[filter.ndim - 1] == 4 ||
                                                    filter[filter.ndim - 1] == 8)) ||
                              (filter.ndim == 7 &&
                               (filter[filter.ndim - 1] == 4 ||
                                filter[filter.ndim - 1] == 8) &&
                               (filter[filter.ndim - 2] == 4 ||
                                filter[filter.ndim - 2] == 8))) &&
                             src[src.ndim - 1] == 4),
                    "NCHW44 require src ndim is 5 and filter's ndim is 6 "
                    ", and last shape two is 4 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
737
        }
738 739 740
        if (param().format == Param::Format::CHWN4) {
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
M
Megvii Engine Team 已提交
741
                            src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
742 743 744 745 746
                    "CHWN4 require src and filter's ndim is 5 or 6, and last "
                    "shape is 4 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
        }
747
        if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
748 749 750 751 752 753
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
                            src[src.ndim - 1] == 64 && filter[filter.ndim - 1] == 64,
                    "NCHW64 require src and filter's ndim is 5 or 6, and "
                    "last shape is 64 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
754
        }
755
    }
M
Megvii Engine Team 已提交
756
    megdnn_assert(img_dim == 2, "currently only convolution on 2D image is supported");
757 758
    auto cflt = make_canonized_filter_meta(src.ndim, filter);
    if (param().format == Param::Format::NCHW ||
759
        param().format == Param::Format::NHWC) {
760 761
        size_t src_or_dst_c_pos = 0;
        size_t src_or_dst_spatial_start = 0;
762
        if (param().format == Param::Format::NCHW) {
763 764 765
            src_or_dst_c_pos = 1;
            src_or_dst_spatial_start = 2;
        } else {
M
Megvii Engine Team 已提交
766
            megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
767 768 769
            src_or_dst_c_pos = 3;
            src_or_dst_spatial_start = 1;
        }
M
Megvii Engine Team 已提交
770 771 772
        megdnn_assert(
                cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
                errmsg().c_str());
773 774 775 776 777 778 779 780 781
        dst.ndim = src.ndim;
        dst[0] = src[0];
        dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
        for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
            dst[i + src_or_dst_spatial_start] = infer_conv_shape(
                    src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                    cflt.stride[i], cflt.padding[i]);
        }
    } else if (param().format == Param::Format::NCHW4) {
M
Megvii Engine Team 已提交
782 783 784 785 786 787
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
788 789 790 791 792
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
793 794 795 796
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
797 798
        dst[4] = 4;
    } else if (param().format == Param::Format::NCHW8) {
M
Megvii Engine Team 已提交
799 800 801 802 803 804
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
805 806 807 808 809
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 8 == 0);
        dst[1] = oc / 8;
M
Megvii Engine Team 已提交
810 811 812 813
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
814 815
        dst[4] = 8;
    } else if (param().format == Param::Format::NCHW32) {
M
Megvii Engine Team 已提交
816 817 818 819 820 821
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
822 823 824 825 826
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 32 == 0);
        dst[1] = oc / 32;
M
Megvii Engine Team 已提交
827 828 829 830
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
831
        dst[4] = 32;
832
    } else if (param().format == Param::Format::NCHW88) {
M
Megvii Engine Team 已提交
833 834 835
        megdnn_assert(
                src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
                "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
836 837 838 839 840
        dst.ndim = 5;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 8 == 0);
        dst[1] = oc / 8;
M
Megvii Engine Team 已提交
841 842 843 844
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
845 846
        dst[4] = 8;
        if (cflt.group == 1) {
M
Megvii Engine Team 已提交
847 848 849 850
            megdnn_assert(
                    cflt.icpg * cflt.group == src[1] * 8 ||
                            (cflt.icpg * cflt.group == src[1]),
                    "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
851 852
        }

M
Megvii Engine Team 已提交
853 854 855 856 857 858
    } else if (
            param().format == Param::Format::NCHW44 ||
            param().format == Param::Format::NCHW44_DOT) {
        megdnn_assert(
                src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
                "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
859 860 861 862 863
        dst.ndim = 5;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
864 865 866 867
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
868 869
        dst[4] = 4;
        if (cflt.group == 1) {
M
Megvii Engine Team 已提交
870 871 872 873
            megdnn_assert(
                    cflt.icpg * cflt.group == src[1] * 4 ||
                            (cflt.icpg * cflt.group == src[1]),
                    "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
874
        }
875
    } else if (param().format == Param::Format::CHWN4) {
M
Megvii Engine Team 已提交
876 877 878 879 880 881
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
882 883 884 885 886
        dst.ndim = src.ndim;
        dst[3] = src[3];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[0] = oc / 4;
M
Megvii Engine Team 已提交
887 888 889 890
        dst[1] = infer_conv_shape(
                src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
891
        dst[4] = 4;
892
    } else if (param().format == Param::Format::NCHW4_NCHW) {
M
Megvii Engine Team 已提交
893 894 895 896 897 898
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
899 900 901 902
        dst.ndim = 4;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        dst[1] = oc;
M
Megvii Engine Team 已提交
903 904 905 906
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
907
    } else if (param().format == Param::Format::NCHW4_NHWC) {
M
Megvii Engine Team 已提交
908 909 910 911 912 913
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
914 915
        dst.ndim = 4;
        dst[0] = src[0];
M
Megvii Engine Team 已提交
916 917 918 919
        dst[1] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[2] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
920 921
        auto oc = cflt.ocpg * cflt.group;
        dst[3] = oc;
922
    } else if (param().format == Param::Format::NCHW4_NCHW32) {
M
Megvii Engine Team 已提交
923 924 925 926 927 928
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
929 930 931 932 933
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 32 == 0);
        dst[1] = oc / 32;
M
Megvii Engine Team 已提交
934 935 936 937
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
938 939
        dst[4] = 32;
    } else if (param().format == Param::Format::NCHW32_NCHW4) {
M
Megvii Engine Team 已提交
940 941 942 943 944 945
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
946 947 948 949 950
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
951 952 953 954
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
955
        dst[4] = 4;
956
    } else if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
957 958 959 960 961 962
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
963 964 965 966 967
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 64 == 0);
        dst[1] = oc / 64;
M
Megvii Engine Team 已提交
968 969 970 971
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
972
        dst[4] = 64;
973 974
    } else {
        megdnn_assert(param().format == Param::Format::NHWCD4);
M
Megvii Engine Team 已提交
975 976 977 978 979 980
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
981 982 983 984 985
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[2] = oc / 4;
M
Megvii Engine Team 已提交
986 987 988 989
        dst[1] = infer_conv_shape(
                src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
990 991 992
        megdnn_assert(src[4] == 4);
        dst[4] = 4;
    }
M
Megvii Engine Team 已提交
993
    if (!src.format.is_default() && !src.format.is_lowbit_aligned()) {  // propagate
M
Megvii Engine Team 已提交
994 995 996
        dst.format = src.format;
    } else {  // determined by dtype
        dst.format = TensorFormat(dst.dtype);
M
Megvii Engine Team 已提交
997
    }
998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011
    dst.init_contiguous_stride();
    return cflt;
}

/**
 * \warning: An explicit specialization shall be declared in a namespace
 * enclosing the specialized template. An explicit specialization whose
 * declarator-id is not qualified shall be declared in the nearest enclosing
 * namespace of the template, or, if the namespace is inline (7.3.1), any
 * namespace from its enclosing namespace set.
 * refer to:
 * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace
 */
template <>
M
Megvii Engine Team 已提交
1012 1013 1014 1015 1016
ConvolutionBase<param::Convolution>::CanonizedFilterMeta ConvolutionBase<
        param::Convolution>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1017 1018
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1019 1020 1021 1022 1023 1024 1025 1026 1027
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

template <>
M
Megvii Engine Team 已提交
1028 1029 1030 1031
ConvolutionBase<param::ConvBias>::CanonizedFilterMeta ConvolutionBase<param::ConvBias>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1032 1033
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1034 1035 1036 1037 1038 1039 1040 1041 1042
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

template <>
M
Megvii Engine Team 已提交
1043 1044 1045 1046 1047
ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta ConvolutionBase<
        param::BatchConvBias>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1048 1049
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

/**
 * \brief Check or deduce the output dtype of a forward convolution.
 *
 * Thin wrapper that forwards all three arguments unchanged to
 * check_or_deduce_dtype_fwd().
 *
 * \param src input dtype
 * \param filter filter dtype
 * \param[in,out] dst output dtype, passed by reference to the callee
 */
void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) {
    check_or_deduce_dtype_fwd(src, filter, dst);
}

M
Megvii Engine Team 已提交
1062 1063
void ConvolutionForward::deduce_layout(
        const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
1064 1065 1066 1067
    deduce_layout_fwd(src, filter, dst);
}

/**
 * \brief Validate the arguments of a forward convolution execution.
 *
 * Checks the layouts through check_layout_fwd() and then asserts that the
 * caller-provided workspace is at least as large as the size reported by
 * get_workspace_in_bytes().
 *
 * \param src input layout
 * \param filter filter layout
 * \param dst output layout
 * \param workspace_in_bytes size of the workspace supplied by the caller
 * \param preprocessed_filter forwarded unchanged to get_workspace_in_bytes()
 * \return the canonized filter meta produced by check_layout_fwd()
 */
ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) {
    auto ret = check_layout_fwd(src, filter, dst);
    const size_t required_workspace_in_bytes =
            get_workspace_in_bytes(src, filter, dst, preprocessed_filter);
    //! report both sizes so a failure is actionable without a debugger
    megdnn_assert(
            workspace_in_bytes >= required_workspace_in_bytes,
            "workspace has %zu bytes, but %zu bytes are required",
            workspace_in_bytes, required_workspace_in_bytes);
    return ret;
}

M
Megvii Engine Team 已提交
1077 1078 1079
/**
 * \brief Validate the arguments of a backward-data convolution execution.
 *
 * The check reuses the forward path: \p grad is passed to
 * check_layout_fwd() in the src position and \p diff in the dst position.
 * Both are copied first so that their dtypes can be swapped and their
 * strides reset with init_contiguous_stride() without touching the caller's
 * layouts.
 *
 * \param filter filter layout
 * \param diff layout passed as the forward dst
 * \param grad layout passed as the forward src
 * \param workspace_in_bytes size of the workspace supplied by the caller
 * \return the canonized filter meta produced by check_layout_fwd()
 */
ConvolutionBackwardData::CanonizedFilterMeta ConvolutionBackwardData::check_exec(
        const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    auto grad_fwd = grad;
    auto diff_fwd = diff;

    //! after the swap the forward "src" carries diff's dtype and the forward
    //! "dst" carries grad's dtype
    //! NOTE(review): presumably this lets the forward dtype rules be reused
    //! for the backward pairing -- confirm against check_or_deduce_dtype_fwd
    std::swap(grad_fwd.dtype, diff_fwd.dtype);

    grad_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    //! filter is never modified, so it is passed directly instead of copied
    auto ret = check_layout_fwd(grad_fwd, filter, diff_fwd);

    const size_t required_workspace_in_bytes =
            get_workspace_in_bytes(filter, diff, grad);
    //! report both sizes so a failure is actionable without a debugger
    megdnn_assert(
            workspace_in_bytes >= required_workspace_in_bytes,
            "workspace has %zu bytes, but %zu bytes are required",
            workspace_in_bytes, required_workspace_in_bytes);
    return ret;
}

M
Megvii Engine Team 已提交
1094
/**
 * \brief Check, or deduce when \p grad is not valid, the dtype of the
 *        backward-data result.
 *
 * Accepted (filter, diff) pairs and the resulting candidates:
 *   - both of FLOAT category: the filter dtype;
 *   - Int8 filter with an equal diff: Int32;
 *   - QuantizedS8 x QuantizedS8 or Quantized8Asymm x Quantized8Asymm:
 *     QuantizedS32 with scale mul_scale(filter, diff), plus \p grad itself
 *     when it is valid and has the same enumv as \p diff.
 * Any other pair throws through megdnn_throw.
 *
 * \param filter filter dtype
 * \param diff diff dtype
 * \param[in,out] grad set to the first candidate when invalid on entry,
 *      otherwise asserted to be one of the candidates
 */
void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad) {
    const auto filter_enum = filter.enumv();
    const auto diff_enum = diff.enumv();

    SmallVector<DType> supported_dst_dtype;
    if (filter.category() == DTypeCategory::FLOAT &&
        diff.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(filter);
    } else if (filter_enum == DTypeEnum::Int8 && diff == filter) {
        supported_dst_dtype.push_back(dtype::Int32());
    } else if (
            filter_enum == diff_enum && (filter_enum == DTypeEnum::QuantizedS8 ||
                                         filter_enum == DTypeEnum::Quantized8Asymm)) {
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
        //! a caller-chosen grad of the same kind as diff is accepted as well
        if (grad.valid() && grad.enumv() == diff_enum) {
            supported_dst_dtype.push_back(grad);
        }
    } else {
        megdnn_throw(ssprintf(
                "unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
    }

    if (grad.valid()) {
        megdnn_assert(
                vec_contains(supported_dst_dtype, grad),
                "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
                grad.name());
    } else {
        grad = supported_dst_dtype.at(0);
    }

    megdnn_assert(
            param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
                    || filter.enumv() == DTypeEnum::Float16 ||
                    filter.enumv() == DTypeEnum::BFloat16
#endif
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
}

M
Megvii Engine Team 已提交
1133 1134
void ConvolutionBackwardData::deduce_layout(
        const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
1135 1136 1137 1138
    auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert_contiguous(filter);
    megdnn_assert_contiguous(diff);
M
Megvii Engine Team 已提交
1139
    megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
1140 1141 1142 1143 1144 1145
    megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());

    deduce_dtype(filter.dtype, diff.dtype, grad.dtype);

    auto cflt = make_canonized_filter_meta(diff.ndim, filter);

M
Megvii Engine Team 已提交
1146
    auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160
        MEGDNN_MARK_USED_VAR(errmsg);
        auto i = (out - 1) * stride + filter;
        megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
        return i - pad * 2;
    };

    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NHWC) {
        size_t src_or_dst_c_pos = 0;
        size_t src_or_dst_spatial_start = 0;
        if (param().format == Param::Format::NCHW) {
            src_or_dst_c_pos = 1;
            src_or_dst_spatial_start = 2;
        } else {
M
Megvii Engine Team 已提交
1161
            megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
1162 1163 1164
            src_or_dst_c_pos = 3;
            src_or_dst_spatial_start = 1;
        }
M
Megvii Engine Team 已提交
1165 1166 1167
        megdnn_assert(
                cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s",
                errmsg().c_str());
1168 1169 1170 1171
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
        for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
1172 1173 1174
            grad[i + src_or_dst_spatial_start] =
                    deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                           cflt.stride[i], cflt.padding[i]);
1175
        }
1176
    } else if (param().format == Param::Format::NCHW4) {
M
Megvii Engine Team 已提交
1177 1178 1179
        megdnn_assert(
                diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
                diff.ndim);
1180
        megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
M
Megvii Engine Team 已提交
1181
        megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
1182 1183 1184 1185 1186
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        auto ic = cflt.icpg * cflt.group;
        megdnn_assert(ic % 4 == 0);
        grad[1] = ic / 4;
M
Megvii Engine Team 已提交
1187 1188 1189 1190
        grad[2] = deduce(
                diff[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        grad[3] = deduce(
                diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
1191 1192
        megdnn_assert(diff[4] == 4);
        grad[4] = 4;
1193 1194
    } else {
        megdnn_assert(param().format == Param::Format::NHWCD4);
M
Megvii Engine Team 已提交
1195 1196 1197 1198
        megdnn_assert(
                diff.ndim == 5, "valid diff ndim for NHWCD4, expected=5, got=%zu",
                diff.ndim);
        megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", errmsg().c_str());
1199 1200 1201 1202 1203
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        auto ic = cflt.icpg * cflt.group;
        megdnn_assert(ic % 4 == 0);
        grad[2] = ic / 4;
M
Megvii Engine Team 已提交
1204 1205 1206 1207
        grad[1] = deduce(
                diff[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        grad[3] = deduce(
                diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
1208 1209 1210 1211 1212 1213 1214
        megdnn_assert(diff[4] == 4);
        grad[4] = 4;
    }
    grad.format = diff.format;
    grad.init_contiguous_stride();
}

M
Megvii Engine Team 已提交
1215 1216 1217 1218 1219 1220 1221 1222
/**
 * \brief Validate the arguments of a backward-filter convolution execution.
 *
 * All three dtypes must be of FLOAT category. The layout check reuses the
 * forward path with \p grad in the filter position; \p src and \p diff are
 * copied so their strides can be reset with init_contiguous_stride() without
 * touching the caller's layouts.
 *
 * \param src layout passed as the forward src
 * \param diff layout passed as the forward dst
 * \param grad layout passed as the forward filter
 * \param workspace_in_bytes size of the workspace supplied by the caller
 * \return the canonized filter meta produced by check_layout_fwd()
 */
ConvolutionBackwardFilter::CanonizedFilterMeta ConvolutionBackwardFilter::check_exec(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    megdnn_assert(
            src.dtype.category() == DTypeCategory::FLOAT &&
                    diff.dtype.category() == DTypeCategory::FLOAT &&
                    grad.dtype.category() == DTypeCategory::FLOAT,
            "only float type is supported for conv backward filter");
    auto src_fwd = src;
    auto diff_fwd = diff;

    src_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    auto ret = check_layout_fwd(src_fwd, grad, diff_fwd);

    const size_t required_workspace_in_bytes =
            get_workspace_in_bytes(src, diff, grad);
    //! report both sizes so a failure is actionable without a debugger
    megdnn_assert(
            workspace_in_bytes >= required_workspace_in_bytes,
            "workspace has %zu bytes, but %zu bytes are required",
            workspace_in_bytes, required_workspace_in_bytes);
    return ret;
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen