convolution.cpp 51.3 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/common/convolution.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

using namespace megdnn;

namespace {
template <typename Param>
M
Megvii Engine Team 已提交
20 21 22
std::string get_errmsg(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const Param& param) {
23 24 25 26
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(dst);
    return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
M
Megvii Engine Team 已提交
27
           megdnn_layout_msg(dst) + ", " + "is_nchw=" +
M
Megvii Engine Team 已提交
28 29 30 31
           std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
           "is_xcorr=" +
           std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " +
           "pad_h=" + std::to_string(param.pad_h) + ", " +
M
Megvii Engine Team 已提交
32 33 34 35 36
           "pad_w=" + std::to_string(param.pad_w) + ", " +
           "stride_h=" + std::to_string(param.stride_h) + ", " +
           "stride_w=" + std::to_string(param.stride_w) + ", " +
           "dilate_h=" + std::to_string(param.dilate_h) + ", " +
           "dilate_w=" + std::to_string(param.dilate_w);
37 38 39 40 41 42 43 44 45 46 47
}

/**
 * Map a raw filter extent to the spatial size stored in the canonized meta.
 *
 * This primary template is the identity: neither the format template
 * argument nor the parameter object influences the result.
 *
 * \param filter extent read from the filter layout
 * \return \p filter unchanged
 */
template <typename Param, typename Param::Format>
uint32_t spatial_getter(uint32_t filter, const Param& /* param */) {
    const uint32_t spatial_extent = filter;
    return spatial_extent;
}

template <typename Parameter, typename Param>
void make_canonized_filter_meta_nchw_nhwc(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
M
Megvii Engine Team 已提交
48 49
    megdnn_assert(
            param.format == Param::Format::NCHW || param.format == Param::Format::NHWC);
50 51 52 53 54 55 56 57 58 59 60
    auto img_ndim = src_ndim - 2;
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = 1;
        flt_start = 0;
    } else {
M
Megvii Engine Team 已提交
61 62 63
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);

        // grp, oc, ic, dims[]
        ret.group = filter[0];
        flt_start = 1;
    }

    uint32_t ic_block_size = 1, oc_block_size = 1;
    if (param.format == Param::Format::NCHW) {
        // filter should be (oc, ic, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
    } else {
M
Megvii Engine Team 已提交
82 83
        megdnn_assert(
                param.format == Param::Format::NHWC, "invalid conv tensor format");
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
        // filter should be (oc, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 3;
    }
    ret.spatial_ndim = src_ndim - 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
    ret.icpg = filter[flt_start + icpg_pos] * ic_block_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
99 100 101
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
102 103
        ret.spatial[i] = spatial_getter<Param, Param::Format::NCHW>(
                filter[i + flt_start + flt_spatial_start], param);
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

/**
 * Canonize filter meta data for the NHWCD4 format (non-dot filter layout).
 *
 * input: N H IC/4 W 4
 * Filter:
 *        OC/4, FH, FW, IC, 4 [dense]
 *        GROUP, OC/4, FH, FW, IC, 4 [group]
 *        GROUP/4, 1, FH, FW, 4 [chanwise]
 *
 * \param src_ndim number of dimensions of the input tensor; must be 5 because
 *      spatial_ndim is computed as src_ndim - 3 and has to equal 2
 * \param filter filter layout the extents are read from
 * \param param convolution parameter (format and sparse are consulted)
 * \param ret receives group, ocpg, icpg, spatial_ndim, spatial and
 *      dilated_spatial; ret.dilation is read and must already be set
 */
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    megdnn_assert(param.format == Param::Format::NHWCD4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // OC/4, FH, FW, IC, 4
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // NOTE(review): a filter with ndim == img_ndim + 3 whose second axis
        // is not 1 is handled by the generic group path below rather than
        // rejected -- confirm this is intended.
        if (filter.ndim == img_ndim + 3 && filter[1] == 1) {
            // chanwise: the leading axis stores GROUP/4
            is_chanwise = true;
            ret.group = filter[0] * 4;
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    // NHWCD4 tensors carry three non-spatial axes, so a 2D convolution needs
    // a 5-dim input.
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        // the OC axis is stored divided by 4; the IC axis is stored in full
        ret.ocpg = filter[flt_start] * 4;
        ret.icpg = filter[flt_start + 3];
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        // Extent covered by the kernel once the dilation gaps are included.
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

/**
 * Canonize filter meta data for the NHWCD4 format with the "dot" filter
 * layout, in which the input-channel axis is also packed by 4.
 *
 * input: N H IC/4 W 4
 * Filter:
 *        GROUP/4, 1, FH, FW, 4 [chanwise]
 *        OC/4, FH, FW, IC/4, 4, 4 [dense]
 *        GROUP, OC/4, FH, FW, IC/4, 4, 4 [group]
 *
 * \param src_ndim number of dimensions of the input tensor; must be 5 because
 *      spatial_ndim is computed as src_ndim - 3 and has to equal 2
 * \param filter filter layout the extents are read from
 * \param param convolution parameter (format and sparse are consulted)
 * \param ret receives group, ocpg, icpg, spatial_ndim, spatial and
 *      dilated_spatial; ret.dilation is read and must already be set
 */
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4_dot(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    megdnn_assert(param.format == Param::Format::NHWCD4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // OC/4, FH, FW, IC/4, 4, 4
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        if (filter.ndim == img_ndim + 3) {
            // chanwise: the leading axis stores GROUP/4 and the second axis
            // must be 1
            megdnn_assert(filter[1] == 1);
            is_chanwise = true;
            ret.group = filter[0] * 4;
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    // NHWCD4 tensors carry three non-spatial axes, so a 2D convolution needs
    // a 5-dim input.
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        // both channel axes are stored divided by 4
        ret.ocpg = filter[flt_start] * 4;
        ret.icpg = filter[flt_start + 3] * 4;
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        // Extent covered by the kernel once the dilation gaps are included.
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwxx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     *
244 245 246 247 248 249 250
     ** NCHW44-DOT mode
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
     * [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *                  FH, FW, pack_size(OC), pack_size(IC)} [group]
     *
251
     * NCHW88 and NCHW44 mode
252 253 254 255 256 257 258 259 260 261
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
     * [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *                  FH, FW, pack_size(IC), pack_size(OC)} [group]
     *        {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
     *
     *
     */

M
Megvii Engine Team 已提交
262 263 264 265
    megdnn_assert(
            param.format == Param::Format::NCHW88 ||
            param.format == Param::Format::NCHW44 ||
            param.format == Param::Format::NCHW44_DOT);
266 267 268
    size_t img_ndim = 2;
    size_t flt_start = 0;
    size_t flt_spatial_start = 2;
269
    size_t pack_c_size = 0;
270 271 272
    if (param.sparse == Param::Sparse::DENSE) {
        if (filter.ndim == img_ndim + 4) {
            // oihw8i8o case
M
Megvii Engine Team 已提交
273 274 275 276 277 278 279
            megdnn_assert(
                    (filter[filter.ndim - 2] == pack_size &&
                     filter[filter.ndim - 1] == pack_size) ||
                            (filter[filter.ndim - 2] == 2 * pack_size &&
                             filter[filter.ndim - 1] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
280 281
            ret.group = 1;
            flt_start = 0;
282 283 284 285 286 287 288 289
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                pack_c_size = 2 * pack_size;
            } else {
                pack_c_size = pack_size;
            }
            ret.ocpg = filter[flt_start] * pack_c_size;
            ret.icpg = filter[flt_start + 1] * pack_c_size;
290 291 292 293 294 295 296 297 298
        } else if (filter.ndim == img_ndim + 3) {
            // ohwi8o
            flt_start = 0;
            flt_spatial_start = 1;
            ret.group = 1;
            ret.ocpg = filter[flt_start] * pack_size;
            ret.icpg = filter[flt_start + 3];

        } else {
M
Megvii Engine Team 已提交
299
            megdnn_assert(0, "not support nchwxx filter dim = %zu", filter.ndim);
300 301
        }
    } else {
M
Megvii Engine Team 已提交
302 303 304
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
305 306 307
        flt_start = 1;
        auto filter_oc = filter[flt_start];
        auto filter_ic = filter[flt_start + 1];
308
        if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4)) {
309
            // Depthwise case goihw8g
M
Megvii Engine Team 已提交
310 311 312 313 314 315 316 317 318
            megdnn_assert(
                    filter.ndim == img_ndim + 4,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    filter[filter.ndim - 1] == pack_size,
                    "last dim of filter must be %zu, but %zu", pack_size,
                    filter[filter.ndim - 1]);
319
            ret.group = filter[0] * pack_size;
320 321 322 323 324
            ret.ocpg = filter_oc;
            ret.icpg = filter_ic;

        } else {
            // norm group case goihw8i8o
M
Megvii Engine Team 已提交
325 326 327 328 329 330 331 332 333 334 335 336
            megdnn_assert(
                    filter.ndim == img_ndim + 5,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    (filter[filter.ndim - 1] == pack_size &&
                     filter[filter.ndim - 2] == pack_size) ||
                            (filter[filter.ndim - 1] == 2 * pack_size &&
                             filter[filter.ndim - 2] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
337 338

            ret.group = filter[0];
339 340 341 342 343 344 345 346
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                ret.ocpg = filter_oc * 2 * pack_size;
                ret.icpg = filter_ic * 2 * pack_size;
            } else {
                ret.ocpg = filter_oc * pack_size;
                ret.icpg = filter_ic * pack_size;
            }
347 348 349
        }
    }
    ret.spatial_ndim = 2;
M
Megvii Engine Team 已提交
350 351 352 353 354 355
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchwxx; "
            "got input dim = %zu",
            src_ndim);
356 357 358

    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
359 360 361 362 363
        megdnn_assert(
                dilation[i] == 1,
                "NCHWXX has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
364
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
365 366 367 368 369 370 371 372 373 374 375 376 377 378
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     * filter:
     *        OC, IC/pack_size, FH, FW, pack_size [dense]
     *        GROUP, OC, IC/pack_size, FH, FW, pack_size [group]
     */
M
Megvii Engine Team 已提交
379 380 381 382 383 384 385 386 387
    megdnn_assert(
            param.format == Param::Format::NCHW4 ||
            param.format == Param::Format::NCHW8 ||
            param.format == Param::Format::NCHW32 ||
            param.format == Param::Format::NCHW4_NCHW ||
            param.format == Param::Format::NCHW4_NHWC ||
            param.format == Param::Format::NCHW4_NCHW32 ||
            param.format == Param::Format::NCHW32_NCHW4 ||
            param.format == Param::Format::NCHW64);
388 389 390
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 2;
    if (param.sparse == Param::Sparse::DENSE) {
M
Megvii Engine Team 已提交
391 392 393 394 395
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
396 397 398 399
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
M
Megvii Engine Team 已提交
400 401 402 403 404 405 406 407
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
408 409 410 411
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
M
Megvii Engine Team 已提交
412 413 414 415 416 417
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchw4; "
            "got input dim = %zu",
            src_ndim);
418 419 420 421
    ret.ocpg = filter[flt_start];
    ret.icpg = filter[flt_start + 1] * pack_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
422 423 424 425 426
        megdnn_assert(
                dilation[i] == 1,
                "NCHW4 has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_chwnx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: IC / pack_size, H, W, N, pack_size
     * Filter:
     *        IC / pack_size, FH, FW, OC, pack_size [dense]
     *        GROUP, icpg / pack_size, FH, FW, ocpg, pack_size [group]
     *        not implemented [chanwise]
     */
    megdnn_assert(param.format == Param::Format::CHWN4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    if (param.sparse == Param::Sparse::DENSE) {
M
Megvii Engine Team 已提交
447 448 449 450 451
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
452 453 454 455
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
M
Megvii Engine Team 已提交
456 457 458 459 460 461 462 463
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
464 465 466 467 468 469 470 471 472 473 474 475 476
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.icpg = filter[flt_start] * pack_size;
    ret.ocpg = filter[flt_start + 3];
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
477 478 479 480 481
        megdnn_assert(
                dilation[i] == 1,
                "CHWNx has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
482 483 484 485 486 487 488 489 490
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}

}  // namespace

namespace megdnn {
template <typename Parameter>
M
Megvii Engine Team 已提交
491 492
typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
        make_canonized_filter_meta(size_t src_ndim, const TensorLayout& filter) const {
493 494 495 496 497 498 499
    megdnn_assert_contiguous(filter);
    CanonizedFilterMeta ret;
    ret.dtype = filter.dtype;
    ret.format = param().format;
    if (param().mode == Mode::CONVOLUTION) {
        ret.should_flip = true;
    } else {
M
Megvii Engine Team 已提交
500
        megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode");
501 502 503 504 505 506 507 508 509 510 511 512
        ret.should_flip = false;
    }
    ret.stride[0] = param().stride_h;
    ret.stride[1] = param().stride_w;
    ret.padding[0] = param().pad_h;
    ret.padding[1] = param().pad_w;
    ret.dilation[0] = param().dilate_h;
    ret.dilation[1] = param().dilate_w;

    if (param().format == Param::Format::NHWCD4) {
        if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 ||
            filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
M
Megvii Engine Team 已提交
513 514
            make_canonized_filter_meta_nhwcd4_dot<Parameter>(
                    src_ndim, filter, param(), ret);
515
        } else {
M
Megvii Engine Team 已提交
516 517
            make_canonized_filter_meta_nhwcd4<Parameter>(
                    src_ndim, filter, param(), ret);
518
        }
M
Megvii Engine Team 已提交
519 520 521 522 523 524
    } else if (
            param().format == Param::Format::NCHW4 ||
            param().format == Param::Format::NCHW4_NCHW ||
            param().format == Param::Format::NCHW4_NHWC ||
            param().format == Param::Format::NCHW4_NCHW32) {
        make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret);
525
    } else if (param().format == Param::Format::NCHW8) {
M
Megvii Engine Team 已提交
526
        make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, param(), ret);
527
    } else if (param().format == Param::Format::NCHW88) {
M
Megvii Engine Team 已提交
528 529 530 531 532 533 534 535 536
        make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW44 ||
            param().format == Param::Format::NCHW44_DOT) {
        make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW32 ||
            param().format == Param::Format::NCHW32_NCHW4) {
        make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, param(), ret);
537
    } else if (param().format == Param::Format::CHWN4) {
M
Megvii Engine Team 已提交
538
        make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, param(), ret);
539
    } else if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
540
        make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, param(), ret);
541
    } else {
M
Megvii Engine Team 已提交
542 543 544 545
        megdnn_assert(
                param().format == Param::Format::NHWC ||
                param().format == Param::Format::NCHW);
        make_canonized_filter_meta_nchw_nhwc<Parameter>(src_ndim, filter, param(), ret);
546 547 548 549 550
    }
    return ret;
}

template <typename Parameter>
M
Megvii Engine Team 已提交
551 552
void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
        DType src, DType filter, DType& dst) const {
553 554 555 556 557 558 559
    // The first one will be the default choice.
    SmallVector<DType> supported_dst_dtype;
    // We rely on megdnn_assert(src.enumv() == filter.enumv()) here.
    if (src.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(src);
    } else if (src.enumv() == DTypeEnum::Int8) {
        supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
M
Megvii Engine Team 已提交
560 561 562 563
    } else if (
            src.enumv() == DTypeEnum::QuantizedS8 ||
            src.enumv() == DTypeEnum::Quantized8Asymm ||
            src.enumv() == DTypeEnum::QuantizedS4 ||
564 565
            src.enumv() == DTypeEnum::Quantized4Asymm ||
            src.enumv() == DTypeEnum::QuantizedS1) {
M
Megvii Engine Team 已提交
566 567 568 569 570 571 572 573
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
        bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
                                        ((dst.enumv() == DTypeEnum::QuantizedS4 ||
                                          dst.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         src.enumv() == DTypeEnum::QuantizedS8) ||
                                        ((src.enumv() == DTypeEnum::QuantizedS4 ||
                                          src.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         dst.enumv() == DTypeEnum::QuantizedS8));
574
        if (cond_dst) {
575 576
            supported_dst_dtype.push_back(dst);
        }
577 578 579
        if (src.enumv() == DTypeEnum::QuantizedS8) {
            supported_dst_dtype.push_back(dtype::Float32());
        }
580 581 582
    } else if (src.enumv() == DTypeEnum::QuantizedS32) {
        //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
        megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
M
Megvii Engine Team 已提交
583 584 585 586 587 588 589
        supported_dst_dtype.push_back(dtype::QuantizedS8(
                src.param<dtype::QuantizedS32>().scale /
                filter.param<dtype::QuantizedS8>().scale));
    } else {
        megdnn_throw(ssprintf(
                "unsupported input / filter DType: %s x %s", src.name(),
                filter.name()));
590 591 592 593
    }
    if (!dst.valid()) {
        dst = supported_dst_dtype.at(0);
    } else {
594 595 596 597 598 599 600 601
        bool dst_supported = false;
        for (auto&& dt : supported_dst_dtype) {
            if (dtype_almost_equal(dt, dst)) {
                dst_supported = true;
                break;
            }
        }
        MEGDNN_MARK_USED_VAR(dst_supported);
M
Megvii Engine Team 已提交
602 603 604
        megdnn_assert(
                dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(),
                filter.name(), dst.name());
605
    }
M
Megvii Engine Team 已提交
606 607 608
    megdnn_assert(
            (param().compute_mode == Param::ComputeMode::FLOAT32 ||
             param().compute_mode == Param::ComputeMode::DEFAULT)
609
#if !MEGDNN_DISABLE_FLOAT16
M
Megvii Engine Team 已提交
610 611
                    || src.enumv() == DTypeEnum::Float16 ||
                    src.enumv() == DTypeEnum::BFloat16
612
#endif
M
Megvii Engine Team 已提交
613 614 615
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
616 617 618
}

template <typename Parameter>
M
Megvii Engine Team 已提交
619 620 621 622
typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
        deduce_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                TensorLayout& dst) const {
623 624 625
    auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
M
Megvii Engine Team 已提交
626 627 628 629 630
    megdnn_assert(
            ((src.dtype.enumv() == filter.dtype.enumv()) ||
             (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
              filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
            "%s", errmsg().c_str());
631 632 633
    check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
    size_t img_dim;
    if (param().format == Param::Format::NCHW ||
634
        param().format == Param::Format::NHWC) {
635
        img_dim = src.ndim - 2;
M
Megvii Engine Team 已提交
636 637 638
        megdnn_assert(
                filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, "%s",
                errmsg().c_str());
639 640

    } else {
M
Megvii Engine Team 已提交
641 642 643 644 645 646 647 648 649 650 651 652 653 654
        megdnn_assert(
                param().format == Param::Format::NHWCD4 ||
                param().format == Param::Format::NCHW4 ||
                param().format == Param::Format::NCHW4_NCHW ||
                param().format == Param::Format::NCHW4_NHWC ||
                param().format == Param::Format::NCHW4_NCHW32 ||
                param().format == Param::Format::NCHW44 ||
                param().format == Param::Format::NCHW44_DOT ||
                param().format == Param::Format::NCHW8 ||
                param().format == Param::Format::NCHW32 ||
                param().format == Param::Format::NCHW32_NCHW4 ||
                param().format == Param::Format::NCHW88 ||
                param().format == Param::Format::CHWN4 ||
                param().format == Param::Format::NCHW64);
655
        img_dim = src.ndim - 3;
656
        if ((param().format == Param::Format::NCHW88 ||
657
             param().format == Param::Format::NCHW44_DOT ||
658 659
             param().format == Param::Format::NCHW44) &&
            filter.ndim == 5) {
660 661
            img_dim = src.ndim - 2;
        }
M
Megvii Engine Team 已提交
662 663 664 665 666 667 668 669
        megdnn_assert(
                filter.ndim == img_dim + 3 ||
                        (filter.ndim == img_dim + 2 &&
                         (param().format == Param::Format::NCHW88 ||
                          param().format == Param::Format::NCHW44_DOT ||
                          param().format == Param::Format::NCHW44)) ||
                        filter.ndim == img_dim + 4 || filter.ndim == img_dim + 5,
                "%s", errmsg().c_str());
670 671 672
        if (param().format == Param::Format::NCHW4 ||
            param().format == Param::Format::NCHW4_NCHW ||
            param().format == Param::Format::NCHW4_NCHW32) {
M
Megvii Engine Team 已提交
673 674 675 676 677 678 679 680 681 682 683 684
            megdnn_assert(
                    src.ndim == 5 &&
                            (filter.ndim == 5 || filter.ndim == 6 ||
                             filter.ndim == 7) &&
                            src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
                    "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and "
                    "filter's ndim is "
                    "5 or 6, and "
                    "last shape "
                    "is 4 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
685 686 687 688
        }
        if (param().format == Param::Format::NCHW8) {
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
M
Megvii Engine Team 已提交
689
                            src[src.ndim - 1] == 8 && filter[filter.ndim - 1] == 8,
690 691 692 693 694
                    "NCHW8 require src and filter's ndim is 5 or 6, and last "
                    "shape is 8 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
        }
695 696
        if (param().format == Param::Format::NCHW32 ||
            param().format == Param::Format::NCHW32_NCHW4) {
M
Megvii Engine Team 已提交
697 698 699 700 701 702 703 704
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
                            src[src.ndim - 1] == 32 && filter[filter.ndim - 1] == 32,
                    "NCHW32/NCHW32_NCHW4 require src and filter's ndim "
                    "is 5 or 6, and last "
                    "shape is 32 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
705
        }
706
        if (param().format == Param::Format::NCHW88) {
M
Megvii Engine Team 已提交
707 708 709 710 711 712 713 714 715 716 717
            megdnn_assert(
                    (src.ndim == 4 && filter.ndim == 5 &&
                     filter[filter.ndim - 1] == 8) ||
                            (src.ndim == 5 &&
                             ((filter.ndim == 6 && filter[filter.ndim - 1] == 8) ||
                              (filter.ndim == 7 && filter[filter.ndim - 1] == 8 &&
                               filter[filter.ndim - 2] == 8)) &&
                             src[src.ndim - 1] == 8),
                    "NCHW88 require src ndim is 5 and filter's ndim is 6 "
                    ", and last shape two is 8 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
718
        }
719
        if (param().format == Param::Format::NCHW44 ||
720
            param().format == Param::Format::NCHW44_DOT) {
721 722
            //! support nchw44 filter change to 88 for int8 winogradf23_88 using
            //! MK8 mamtul
M
Megvii Engine Team 已提交
723 724 725 726 727 728 729 730 731 732 733 734 735 736 737
            megdnn_assert(
                    (src.ndim == 4 && filter.ndim == 5 &&
                     filter[filter.ndim - 1] == 4) ||
                            (src.ndim == 5 &&
                             ((filter.ndim == 6 && (filter[filter.ndim - 1] == 4 ||
                                                    filter[filter.ndim - 1] == 8)) ||
                              (filter.ndim == 7 &&
                               (filter[filter.ndim - 1] == 4 ||
                                filter[filter.ndim - 1] == 8) &&
                               (filter[filter.ndim - 2] == 4 ||
                                filter[filter.ndim - 2] == 8))) &&
                             src[src.ndim - 1] == 4),
                    "NCHW44 require src ndim is 5 and filter's ndim is 6 "
                    ", and last shape two is 4 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
738
        }
739 740 741
        if (param().format == Param::Format::CHWN4) {
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
M
Megvii Engine Team 已提交
742
                            src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
743 744 745 746 747
                    "CHWN4 require src and filter's ndim is 5 or 6, and last "
                    "shape is 4 "
                    "but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
        }
748
        if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
749 750 751 752 753 754
            megdnn_assert(
                    src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
                            src[src.ndim - 1] == 64 && filter[filter.ndim - 1] == 64,
                    "NCHW64 require src and filter's ndim is 5 or 6, and "
                    "last shape is 64 but got src %s, filter %s",
                    src.to_string().c_str(), filter.to_string().c_str());
755
        }
756
    }
M
Megvii Engine Team 已提交
757
    megdnn_assert(img_dim == 2, "currently only convolution on 2D image is supported");
758 759
    auto cflt = make_canonized_filter_meta(src.ndim, filter);
    if (param().format == Param::Format::NCHW ||
760
        param().format == Param::Format::NHWC) {
761 762
        size_t src_or_dst_c_pos = 0;
        size_t src_or_dst_spatial_start = 0;
763
        if (param().format == Param::Format::NCHW) {
764 765 766
            src_or_dst_c_pos = 1;
            src_or_dst_spatial_start = 2;
        } else {
M
Megvii Engine Team 已提交
767
            megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
768 769 770
            src_or_dst_c_pos = 3;
            src_or_dst_spatial_start = 1;
        }
M
Megvii Engine Team 已提交
771 772 773
        megdnn_assert(
                cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
                errmsg().c_str());
774 775 776 777 778 779 780 781 782
        dst.ndim = src.ndim;
        dst[0] = src[0];
        dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
        for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
            dst[i + src_or_dst_spatial_start] = infer_conv_shape(
                    src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                    cflt.stride[i], cflt.padding[i]);
        }
    } else if (param().format == Param::Format::NCHW4) {
M
Megvii Engine Team 已提交
783 784 785 786 787 788
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
789 790 791 792 793
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
794 795 796 797
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
798 799
        dst[4] = 4;
    } else if (param().format == Param::Format::NCHW8) {
M
Megvii Engine Team 已提交
800 801 802 803 804 805
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
806 807 808 809 810
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 8 == 0);
        dst[1] = oc / 8;
M
Megvii Engine Team 已提交
811 812 813 814
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
815 816
        dst[4] = 8;
    } else if (param().format == Param::Format::NCHW32) {
M
Megvii Engine Team 已提交
817 818 819 820 821 822
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
823 824 825 826 827
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 32 == 0);
        dst[1] = oc / 32;
M
Megvii Engine Team 已提交
828 829 830 831
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
832
        dst[4] = 32;
833
    } else if (param().format == Param::Format::NCHW88) {
834
        megdnn_assert(
835
                src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
836
                "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
837 838 839 840 841
        dst.ndim = 5;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 8 == 0);
        dst[1] = oc / 8;
M
Megvii Engine Team 已提交
842 843 844 845
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
846 847
        dst[4] = 8;
        if (cflt.group == 1) {
M
Megvii Engine Team 已提交
848 849 850 851
            megdnn_assert(
                    cflt.icpg * cflt.group == src[1] * 8 ||
                            (cflt.icpg * cflt.group == src[1]),
                    "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
852 853
        }

854 855 856 857
    } else if (
            param().format == Param::Format::NCHW44 ||
            param().format == Param::Format::NCHW44_DOT) {
        megdnn_assert(
858
                src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
859
                "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
860 861 862 863 864
        dst.ndim = 5;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
865 866 867 868
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
869 870
        dst[4] = 4;
        if (cflt.group == 1) {
M
Megvii Engine Team 已提交
871 872 873 874
            megdnn_assert(
                    cflt.icpg * cflt.group == src[1] * 4 ||
                            (cflt.icpg * cflt.group == src[1]),
                    "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
875
        }
876
    } else if (param().format == Param::Format::CHWN4) {
M
Megvii Engine Team 已提交
877 878 879 880 881 882
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
883 884 885 886 887
        dst.ndim = src.ndim;
        dst[3] = src[3];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[0] = oc / 4;
M
Megvii Engine Team 已提交
888 889 890 891
        dst[1] = infer_conv_shape(
                src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
892
        dst[4] = 4;
893
    } else if (param().format == Param::Format::NCHW4_NCHW) {
M
Megvii Engine Team 已提交
894 895 896 897 898 899
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
900 901 902 903
        dst.ndim = 4;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        dst[1] = oc;
M
Megvii Engine Team 已提交
904 905 906 907
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
908
    } else if (param().format == Param::Format::NCHW4_NHWC) {
M
Megvii Engine Team 已提交
909 910 911 912 913 914
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
915 916
        dst.ndim = 4;
        dst[0] = src[0];
M
Megvii Engine Team 已提交
917 918 919 920
        dst[1] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[2] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
921 922
        auto oc = cflt.ocpg * cflt.group;
        dst[3] = oc;
923
    } else if (param().format == Param::Format::NCHW4_NCHW32) {
M
Megvii Engine Team 已提交
924 925 926 927 928 929
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
930 931 932 933 934
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 32 == 0);
        dst[1] = oc / 32;
M
Megvii Engine Team 已提交
935 936 937 938
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
939 940
        dst[4] = 32;
    } else if (param().format == Param::Format::NCHW32_NCHW4) {
M
Megvii Engine Team 已提交
941 942 943 944 945 946
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
947 948 949 950 951
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[1] = oc / 4;
M
Megvii Engine Team 已提交
952 953 954 955
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
956
        dst[4] = 4;
957
    } else if (param().format == Param::Format::NCHW64) {
M
Megvii Engine Team 已提交
958 959 960 961 962 963
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
964 965 966 967 968
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 64 == 0);
        dst[1] = oc / 64;
M
Megvii Engine Team 已提交
969 970 971 972
        dst[2] = infer_conv_shape(
                src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
973
        dst[4] = 64;
974 975
    } else {
        megdnn_assert(param().format == Param::Format::NHWCD4);
M
Megvii Engine Team 已提交
976 977 978 979 980 981
        megdnn_assert(
                src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
                src.ndim);
        megdnn_assert(
                cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u",
                errmsg().c_str(), cflt.icpg, cflt.group);
982 983 984 985 986
        dst.ndim = src.ndim;
        dst[0] = src[0];
        auto oc = cflt.ocpg * cflt.group;
        megdnn_assert(oc % 4 == 0);
        dst[2] = oc / 4;
M
Megvii Engine Team 已提交
987 988 989 990
        dst[1] = infer_conv_shape(
                src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        dst[3] = infer_conv_shape(
                src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
991 992 993
        megdnn_assert(src[4] == 4);
        dst[4] = 4;
    }
M
Megvii Engine Team 已提交
994
    if (!src.format.is_default() && !src.format.is_lowbit_aligned()) {  // propagate
M
Megvii Engine Team 已提交
995 996 997
        dst.format = src.format;
    } else {  // determined by dtype
        dst.format = TensorFormat(dst.dtype);
M
Megvii Engine Team 已提交
998
    }
999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012
    dst.init_contiguous_stride();
    return cflt;
}

/**
 * \warning: An explicit specialization shall be declared in a namespace
 * enclosing the specialized template. An explicit specialization whose
 * declarator-id is not qualified shall be declared in the nearest enclosing
 * namespace of the template, or, if the namespace is inline (7.3.1), any
 * namespace from its enclosing namespace set.
 * refer to:
 * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace
 */
template <>
M
Megvii Engine Team 已提交
1013 1014 1015 1016 1017
ConvolutionBase<param::Convolution>::CanonizedFilterMeta ConvolutionBase<
        param::Convolution>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1018 1019
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1020 1021 1022 1023 1024 1025 1026 1027 1028
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

template <>
M
Megvii Engine Team 已提交
1029 1030 1031 1032
ConvolutionBase<param::ConvBias>::CanonizedFilterMeta ConvolutionBase<param::ConvBias>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1033 1034
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1035 1036 1037 1038 1039 1040 1041 1042 1043
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

template <>
M
Megvii Engine Team 已提交
1044 1045 1046 1047 1048
ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta ConvolutionBase<
        param::BatchConvBias>::
        check_layout_fwd(
                const TensorLayout& src, const TensorLayout& filter,
                const TensorLayout& dst) const {
1049 1050
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;

    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

//! Check or deduce the output dtype \p dst from the \p src and \p filter
//! dtypes; all the work is delegated to check_or_deduce_dtype_fwd().
void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) {
    check_or_deduce_dtype_fwd(src, filter, dst);
}

M
Megvii Engine Team 已提交
1063 1064
void ConvolutionForward::deduce_layout(
        const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
1065 1066 1067 1068
    deduce_layout_fwd(src, filter, dst);
}

ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec(
M
Megvii Engine Team 已提交
1069 1070
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) {
1071
    auto ret = check_layout_fwd(src, filter, dst);
1072
    auto required_workspace_in_bytes =
1073
            get_workspace_in_bytes(src, filter, dst, preprocessed_filter);
1074 1075 1076 1077
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

M
Megvii Engine Team 已提交
1078 1079 1080
ConvolutionBackwardData::CanonizedFilterMeta ConvolutionBackwardData::check_exec(
        const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
1081 1082 1083 1084 1085 1086 1087 1088 1089
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;

    std::swap(grad_fwd.dtype, diff_fwd.dtype);

    grad_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
M
Megvii Engine Team 已提交
1090
    auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad);
1091 1092 1093 1094
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

M
Megvii Engine Team 已提交
1095
//! Check or deduce the dtype of \p grad from the \p filter and \p diff dtypes.
//!
//! Accepted combinations:
//!  - both in the FLOAT category      -> dtype of \p filter
//!  - Int8 x Int8                     -> Int32
//!  - QuantizedS8 x QuantizedS8 or Quantized8Asymm x Quantized8Asymm
//!        -> QuantizedS32 with scale mul_scale(filter, diff); a valid \p grad
//!           whose dtype enum equals that of \p diff is accepted as well
//! Any other combination throws. An invalid \p grad is set to the first
//! accepted dtype; a valid one must be among the accepted dtypes.
void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad) {
    const bool both_float = filter.category() == DTypeCategory::FLOAT &&
                            diff.category() == DTypeCategory::FLOAT;
    const bool both_int8 = filter.enumv() == DTypeEnum::Int8 && diff == filter;
    const bool both_quantized8 =
            filter.enumv() == diff.enumv() &&
            (filter.enumv() == DTypeEnum::QuantizedS8 ||
             filter.enumv() == DTypeEnum::Quantized8Asymm);

    SmallVector<DType> supported_dst_dtype;
    if (both_float) {
        supported_dst_dtype.push_back(filter);
    } else if (both_int8) {
        supported_dst_dtype.push_back(dtype::Int32());
    } else if (both_quantized8) {
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
        const bool grad_matches_diff = grad.valid() && grad.enumv() == diff.enumv();
        if (grad_matches_diff) {
            supported_dst_dtype.push_back(grad);
        }
    } else {
        megdnn_throw(ssprintf(
                "unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
    }

    if (grad.valid()) {
        megdnn_assert(
                vec_contains(supported_dst_dtype, grad),
                "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
                grad.name());
    } else {
        grad = supported_dst_dtype.at(0);
    }

    //! ComputeMode::FLOAT32 is rejected unless float16 support is compiled in
    //! and the filter is Float16 / BFloat16
    megdnn_assert(
            param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
                    || filter.enumv() == DTypeEnum::Float16 ||
                    filter.enumv() == DTypeEnum::BFloat16
#endif
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
}

M
Megvii Engine Team 已提交
1134 1135
void ConvolutionBackwardData::deduce_layout(
        const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
1136 1137 1138 1139
    auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert_contiguous(filter);
    megdnn_assert_contiguous(diff);
M
Megvii Engine Team 已提交
1140
    megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
1141 1142 1143 1144 1145 1146
    megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());

    deduce_dtype(filter.dtype, diff.dtype, grad.dtype);

    auto cflt = make_canonized_filter_meta(diff.ndim, filter);

M
Megvii Engine Team 已提交
1147
    auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161
        MEGDNN_MARK_USED_VAR(errmsg);
        auto i = (out - 1) * stride + filter;
        megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
        return i - pad * 2;
    };

    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NHWC) {
        size_t src_or_dst_c_pos = 0;
        size_t src_or_dst_spatial_start = 0;
        if (param().format == Param::Format::NCHW) {
            src_or_dst_c_pos = 1;
            src_or_dst_spatial_start = 2;
        } else {
M
Megvii Engine Team 已提交
1162
            megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
1163 1164 1165
            src_or_dst_c_pos = 3;
            src_or_dst_spatial_start = 1;
        }
M
Megvii Engine Team 已提交
1166 1167 1168
        megdnn_assert(
                cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s",
                errmsg().c_str());
1169 1170 1171 1172
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
        for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
M
Megvii Engine Team 已提交
1173 1174 1175
            grad[i + src_or_dst_spatial_start] =
                    deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                           cflt.stride[i], cflt.padding[i]);
1176
        }
1177
    } else if (param().format == Param::Format::NCHW4) {
M
Megvii Engine Team 已提交
1178 1179 1180
        megdnn_assert(
                diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
                diff.ndim);
1181
        megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
M
Megvii Engine Team 已提交
1182
        megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
1183 1184 1185 1186 1187
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        auto ic = cflt.icpg * cflt.group;
        megdnn_assert(ic % 4 == 0);
        grad[1] = ic / 4;
M
Megvii Engine Team 已提交
1188 1189 1190 1191
        grad[2] = deduce(
                diff[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        grad[3] = deduce(
                diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
1192 1193
        megdnn_assert(diff[4] == 4);
        grad[4] = 4;
1194 1195
    } else {
        megdnn_assert(param().format == Param::Format::NHWCD4);
M
Megvii Engine Team 已提交
1196 1197 1198 1199
        megdnn_assert(
                diff.ndim == 5, "valid diff ndim for NHWCD4, expected=5, got=%zu",
                diff.ndim);
        megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", errmsg().c_str());
1200 1201 1202 1203 1204
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        auto ic = cflt.icpg * cflt.group;
        megdnn_assert(ic % 4 == 0);
        grad[2] = ic / 4;
M
Megvii Engine Team 已提交
1205 1206 1207 1208
        grad[1] = deduce(
                diff[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
        grad[3] = deduce(
                diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
1209 1210 1211 1212 1213 1214 1215
        megdnn_assert(diff[4] == 4);
        grad[4] = 4;
    }
    grad.format = diff.format;
    grad.init_contiguous_stride();
}

M
Megvii Engine Team 已提交
1216 1217 1218 1219 1220 1221 1222 1223
ConvolutionBackwardFilter::CanonizedFilterMeta ConvolutionBackwardFilter::check_exec(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    megdnn_assert(
            src.dtype.category() == DTypeCategory::FLOAT &&
                    diff.dtype.category() == DTypeCategory::FLOAT &&
                    grad.dtype.category() == DTypeCategory::FLOAT,
            "only float type is supported for conv backward filter");
1224 1225 1226 1227 1228 1229
    auto src_fwd = src;
    auto diff_fwd = diff;

    src_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    auto ret = check_layout_fwd(src_fwd, grad, diff_fwd);
1230 1231 1232 1233 1234 1235 1236 1237
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen