/**
 * \file dnn/src/naive/convolution/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

#include <cstring>

namespace megdnn {
namespace naive {
namespace convolution {

//! Splits a running channel index into (group index, offset inside group).
//!
//! Every call to next() moves one channel forward; when the offset reaches
//! grp_size it wraps to zero and the group index advances by one.
struct GroupCounter {
    //! number of channels in each group
    const size_t grp_size;
    //! index of the group the current channel belongs to
    size_t cur_grp = 0;
    //! position of the current channel inside its group
    size_t cur_off = 0;

    explicit GroupCounter(size_t group_size) : grp_size{group_size} {}

    //! advance to the following channel
    void next() {
        ++cur_off;
        if (cur_off != grp_size) {
            return;
        }
        cur_off = 0;
        ++cur_grp;
    }
};

//! Forward-convolution policy: the output element is the accumulator. It is
//! reset to zero, receives s * f for every tap, and is then stored back.
struct StrategyFwd {
    //! reset the accumulator before the reduction over filter taps
    template <typename dt>
    static void init_dval(dt& d) {
        d = static_cast<dt>(0);
    }

    //! d += s * f, both factors converted to the compute type first
    template <typename st, typename ft, typename ct>
    static void on(st& s, ft& f, ct& d, DType, DType, DType) {
        ct src_val = static_cast<ct>(s);
        ct filt_val = static_cast<ct>(f);
        d += src_val * filt_val;
    }

    //! store the finished accumulator into the destination element
    template <typename ct, typename dt>
    static void write(ct& d, dt& dst) {
        dst = static_cast<dt>(d);
    }
};

// An explicit specialization of a member function template may not be
// declared inside the class body (a defect in the C++ specification that is
// fixed in C++17). As a workaround, the specializations below are defined
// outside the class and marked inline so this header can be included from
// several translation units.

//! quint8 x quint8 -> qint32: each operand has its zero point subtracted
//! (taken from its own dtype) before the product is accumulated.
template <>
inline void StrategyFwd::on(dt_quint8& s, dt_quint8& f, dt_qint32& d,
                            DType src_dt, DType filt_dt, DType) {
    const int32_t src_zp = src_dt.param<dtype::Quantized8Asymm>().zero_point;
    const int32_t filt_zp =
            filt_dt.param<dtype::Quantized8Asymm>().zero_point;
    dt_qint32 src_centered(static_cast<int32_t>(s.as_uint8()) - src_zp);
    dt_qint32 filt_centered(static_cast<int32_t>(f.as_uint8()) - filt_zp);
    d += src_centered * filt_centered;
}

//! qint8 x qint8 -> qint32: raw int8 values are widened to int32 and
//! multiplied; no zero point is subtracted here.
template <>
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_qint32& d, DType,
                            DType, DType) {
    const int32_t src_raw = s.as_int8();
    const int32_t filt_raw = f.as_int8();
    d += dt_qint32(src_raw) * dt_qint32(filt_raw);
}

//! Backward-data policy: the element passed as `s` is accumulated in place
//! with f * d. init_dval() and write() do nothing, so the value read from
//! the output-side tensor is used as-is and never written back.
struct StrategyBwdData {
    template <typename dt>
    static void init_dval(dt&) {}

    //! s += f * d, computed in the type of s
    template <typename st, typename ft, typename dt>
    static void on(st& s, ft& f, dt& d, DType, DType, DType) {
        st filt_val = static_cast<st>(f);
        st diff_val = static_cast<st>(d);
        s += filt_val * diff_val;
    }

    template <typename ct, typename dt>
    static void write(ct&, dt&) {}
};

//! Raw signed-char filter/diff with an int accumulator: both operands are
//! read through dt_qint8 and widened to int32 before multiplying.
template <>
inline void StrategyBwdData::on(int& s, signed char& f, signed char& d, DType,
                                DType, DType) {
    const int32_t filt_val = megdnn::dt_qint8(f).as_int8();
    const int32_t diff_val = megdnn::dt_qint8(d).as_int8();
    s += filt_val * diff_val;
}

//! quint8 filter and diff with a qint32 gradient: each operand has the zero
//! point of its own dtype (filter / diff) removed before multiplying.
template <>
inline void StrategyBwdData::on(dt_qint32& s, dt_quint8& f, dt_quint8& d, DType,
                                DType filt_dt, DType dst_dt) {
    const int32_t filt_zp =
            filt_dt.param<dtype::Quantized8Asymm>().zero_point;
    const int32_t diff_zp = dst_dt.param<dtype::Quantized8Asymm>().zero_point;
    dt_qint32 filt_centered(static_cast<int32_t>(f.as_uint8()) - filt_zp);
    dt_qint32 diff_centered(static_cast<int32_t>(d.as_uint8()) - diff_zp);
    s += filt_centered * diff_centered;
}

//! qint8 filter and diff with a qint32 gradient: raw int8 values are widened
//! to int32 and multiplied; no zero point is subtracted here.
template <>
inline void StrategyBwdData::on(dt_qint32& s, dt_qint8& f, dt_qint8& d, DType,
                                DType, DType) {
    const int32_t filt_raw = f.as_int8();
    const int32_t diff_raw = d.as_int8();
    s += dt_qint32(filt_raw) * dt_qint32(diff_raw);
}

//! Backward-filter policy: the filter element `f` is accumulated in place
//! with s * d. init_dval() and write() do nothing, so the output-side value
//! is read unchanged and never written back.
struct StrategyBwdFlt {
    template <typename dt>
    static void init_dval(dt&) {}

    //! f += s * d, computed in the type of f
    template <typename st, typename ft, typename dt>
    static void on(st& s, ft& f, dt& d, DType, DType, DType) {
        ft src_val = static_cast<ft>(s);
        ft diff_val = static_cast<ft>(d);
        f += src_val * diff_val;
    }

    template <typename ct, typename dt>
    static void write(ct&, dt&) {}
};

//! Default filter visitor: every output position shares one filter, so the
//! base pointer is handed back untouched. The trailing arguments (batch,
//! oc, oh, ow, filter_sizes) exist only to match the visitor interface.
struct ConvFilterVisitor {
    template <typename ftype>
    static ftype* get_current_ptr(ftype* fptr, size_t, size_t, size_t, size_t,
                                  size_t) {
        ftype* const shared_filter = fptr;
        return shared_filter;
    }
};

/*!
 * \brief Naive 2D convolution loop shared by forward, backward-data and
 *      backward-filter.
 *
 * The loop nest visits every output position (n, oc, oh, ow) of \p dst. For
 * each filter tap that lands inside \p src, it then visits every input
 * channel of the group that oc belongs to. \p Strategy decides what is done
 * with the three operands:
 * - StrategyFwd accumulates into a local value that is written to \p dst.
 * - StrategyBwdData accumulates in place into the \p src element.
 * - StrategyBwdFlt accumulates in place into the filter element.
 *
 * \tparam FilterVisitor maps the base filter pointer to the filter used at a
 *      given output position (ConvFilterVisitor always returns \p fptr).
 * \param src tensor indexed with input spatial coordinates
 * \param fptr base pointer of the filter data, addressed according to
 *      filter_meta.format
 * \param dst tensor indexed with output spatial coordinates
 * \param filter_meta canonized filter description (format, group sizes,
 *      spatial size, padding, stride, dilation, flip)
 */
template <typename stype, typename ftype, typename dtype, typename comp_type,
          class Strategy, typename FilterMeta,
          typename FilterVisitor = ConvFilterVisitor>
void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
               _megdnn_tensor_out dst, const FilterMeta& filter_meta) {
    // axis index of the first spatial dim, the channel dim and the batch dim
    size_t spatial_start, channel_pos, batch_pos;
    using Format = param::Convolution::Format;
    if (filter_meta.format == Format::NCHW ||
        filter_meta.format == Format::NCHW88 ||
        filter_meta.format == Format::NCHW44 ||
        filter_meta.format == Format::NCHW44_DOT ||
        filter_meta.format == Format::NCHW4 ||
        filter_meta.format == Format::NCHW8 ||
        filter_meta.format == Format::NCHW32) {
        spatial_start = 2;
        channel_pos = 1;
        batch_pos = 0;
    } else if (filter_meta.format == Format::CHWN4) {
        spatial_start = 1;
        channel_pos = 0;
        batch_pos = 3;
    } else {
        megdnn_assert(filter_meta.format == Format::NHWC,
                      "invalid conv format");
        spatial_start = 1;
        channel_pos = 3;
        batch_pos = 0;
    }

    auto N = src.layout.shape[batch_pos], IH = src.layout.shape[spatial_start],
         IW = src.layout.shape[spatial_start + 1];
    auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1];
    auto OC = dst.layout.shape[channel_pos],
         OH = dst.layout.shape[spatial_start],
         OW = dst.layout.shape[spatial_start + 1];

    // In packed formats the channel axis of dst counts channel blocks;
    // scale by the pack size to obtain the number of output channels.
    if (filter_meta.format == Format::NCHW4 ||
        filter_meta.format == Format::CHWN4 ||
        filter_meta.format == Format::NCHW44_DOT ||
        filter_meta.format == Format::NCHW44) {
        OC *= 4;
    } else if (filter_meta.format == Format::NCHW8 ||
               filter_meta.format == Format::NCHW88) {
        OC *= 8;
    } else if (filter_meta.format == Format::NCHW32) {
        OC *= 32;
    }

    // Element strides of the filter's group (FS_G), output-channel (FS_OC),
    // input-channel (FS_IC) and spatial (fh * FW + fw, FS_SPATIAL) axes.
    // The packed branches of get_filter_addr reinterpret them per format.
    size_t FS_G, FS_OC, FS_IC, FS_SPATIAL;
    if (filter_meta.format == Format::NCHW ||
        filter_meta.format == Format::NCHW4 ||
        filter_meta.format == Format::NCHW8 ||
        filter_meta.format == Format::NCHW32) {
        // g, oc, ic, fh, fw
        FS_SPATIAL = 1;
        FS_IC = FH * FW;
        FS_OC = FS_IC * filter_meta.icpg;
        FS_G = FS_OC * filter_meta.ocpg;
    } else if (filter_meta.format == Format::CHWN4) {
        // g, ic, fh, fw, oc, pack_size
        FS_SPATIAL = filter_meta.ocpg * 4;
        FS_IC = FH * FW * FS_SPATIAL;
        FS_OC = 4;
        FS_G = FS_IC * filter_meta.icpg;
    } else if (filter_meta.format == Format::NCHW88) {
        if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
            src.layout.ndim == 5 && filter_meta.ocpg == 1) {
            // channel-wise
            FS_SPATIAL = 8;
            FS_IC = FH * FW * FS_SPATIAL;
            FS_OC = FS_IC * filter_meta.icpg;
            FS_G = FS_OC * filter_meta.ocpg;
        } else {
            if (src.layout.ndim == 4 && dst.layout.ndim == 5) {
                // hybrid: plain input, packed output
                FS_IC = 8;
                FS_SPATIAL = filter_meta.icpg * FS_IC;
                FS_OC = FH * FW * FS_SPATIAL;
                FS_G = FS_OC * filter_meta.ocpg / 8;
            } else {
                FS_SPATIAL = 8 * 8;
                FS_IC = FH * FW * FS_SPATIAL;
                FS_OC = FS_IC * filter_meta.icpg / 8;
                FS_G = FS_OC * filter_meta.ocpg / 8;
            }
        }
    } else if (filter_meta.format == Format::NCHW44 ||
               filter_meta.format == Format::NCHW44_DOT) {
        if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
            src.layout.ndim == 5 && filter_meta.ocpg == 1) {
            // channel-wise
            FS_SPATIAL = 4;
            FS_IC = FH * FW * FS_SPATIAL;
            FS_OC = FS_IC * filter_meta.icpg;
            FS_G = FS_OC * filter_meta.ocpg;
        } else {
            if (src.layout.ndim == 4 && dst.layout.ndim == 5) {
                // hybrid: plain input, packed output
                FS_IC = 4;
                FS_SPATIAL = filter_meta.icpg * FS_IC;
                FS_OC = FH * FW * FS_SPATIAL;
                FS_G = FS_OC * filter_meta.ocpg / 4;
            } else {
                FS_SPATIAL = 4 * 4;
                FS_IC = FH * FW * FS_SPATIAL;
                FS_OC = FS_IC * filter_meta.icpg / 4;
                FS_G = FS_OC * filter_meta.ocpg / 4;
            }
        }
    } else {
        // g, oc, fh, fw, ic
        megdnn_assert(filter_meta.format == Format::NHWC);
        FS_IC = 1;
        FS_SPATIAL = filter_meta.icpg;
        FS_OC = FS_SPATIAL * FH * FW;
        FS_G = FS_OC * filter_meta.ocpg;
    }
    int ph = filter_meta.padding[0], pw = filter_meta.padding[1];
    size_t sh = filter_meta.stride[0], sw = filter_meta.stride[1];
    int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1];
    stype* __restrict sptr = src.compatible_ptr<stype>();
    dtype* __restrict dptr = dst.compatible_ptr<dtype>();

    // A flipped kernel is walked backwards: start at the last dilated tap
    // and use negative dilation steps.
    int h_offset = -ph, w_offset = -pw;
    if (filter_meta.should_flip) {
        h_offset += filter_meta.dilated_spatial[0] - 1;
        w_offset += filter_meta.dilated_spatial[1] - 1;
        dh = -dh;
        dw = -dw;
    }

    // Element offset of logical (n, c, h, w) inside a src/dst tensor.
    auto get_linear_addr = [&filter_meta, &src](ptrdiff_t n, ptrdiff_t c,
                                                ptrdiff_t h, ptrdiff_t w,
                                                const TensorLayout& layout,
                                                bool is_output) -> ptrdiff_t {
        if (filter_meta.format == Format::NCHW) {
            return n * layout.stride[0] + c * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3];
        } else if (filter_meta.format == Format::NHWC) {
            return n * layout.stride[0] + h * layout.stride[1] +
                   w * layout.stride[2] + c * layout.stride[3];
        } else if (filter_meta.format == Format::NCHW8 ||
                   filter_meta.format == Format::NCHW88) {
            if (filter_meta.format == Format::NCHW88 && !is_output &&
                src.layout.ndim == 4) {
                // hybrid mode: the 4-dim input is plain NCHW
                return n * layout.stride[0] + c * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 8) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b111) * layout.stride[4];
            }
        } else if (filter_meta.format == Format::NCHW44 ||
                   filter_meta.format == Format::NCHW44_DOT) {
            // Hybrid mode: a 4-dim input is plain NCHW while the output
            // stays packed. The FS_* setup and get_filter_addr accept this
            // for NCHW44_DOT as well as NCHW44, so both must address the
            // input as NCHW here. The packed branch would read the
            // nonexistent 5th stride of a 4-dim layout.
            if (!is_output && src.layout.ndim == 4) {
                return n * layout.stride[0] + c * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c % 4) * layout.stride[4];
            }
        } else if (filter_meta.format == Format::NCHW32) {
            return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3] +
                   (c & 0x1F) * layout.stride[4];
        } else if (filter_meta.format == Format::CHWN4) {
            return (c / 4) * layout.stride[0] + h * layout.stride[1] +
                   w * layout.stride[2] + n * layout.stride[3] +
                   (c % 4) * layout.stride[4];
        } else {
            megdnn_assert(filter_meta.format == Format::NCHW4,
                          "invalid conv format");
            return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3] +
                   (c & 0b11) * layout.stride[4];
        }
    };

    // Element offset of the filter weight connecting output channel
    // (gc_out.cur_grp, gc_out.cur_off) with input channel ic at tap (fh, fw);
    // ic0 is the first input channel of the group.
    auto get_filter_addr = [&](GroupCounter& gc_out, size_t ic, size_t ic0,
                               size_t fh, size_t fw) {
        if (filter_meta.format == Format::NCHW4) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 4 * FS_IC * 4 +
                   (fh * FW + fw) * FS_SPATIAL * 4 + ((ic - ic0) & 0b11);
        } else if (filter_meta.format == Format::NCHW8) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 8 * FS_IC * 8 +
                   (fh * FW + fw) * FS_SPATIAL * 8 + ((ic - ic0) & 0b111);
        } else if (filter_meta.format == Format::NCHW32) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 32 * FS_IC * 32 +
                   (fh * FW + fw) * FS_SPATIAL * 32 + ((ic - ic0) & 0x1F);
        } else if (filter_meta.format == Format::CHWN4) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
                   ((ic - ic0) % 4);
        } else if (filter_meta.format == Format::NCHW88 ||
                   filter_meta.format == Format::NCHW44) {
            size_t pack_c_size = 4_z;
            if (filter_meta.format == Format::NCHW88) {
                pack_c_size = 8_z;
            }
            if (src.layout.ndim == 4) {
                // ic < 8, input is nchw
                return gc_out.cur_grp * FS_G +
                       gc_out.cur_off / pack_c_size * FS_OC +
                       (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
                       gc_out.cur_off % pack_c_size;
            } else if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
                       filter_meta.ocpg == 1 && src.layout.ndim == 5) {
                // dw case
                return gc_out.cur_grp / pack_c_size * FS_G +
                       gc_out.cur_off * FS_OC + (ic - ic0) * FS_IC +
                       (fh * FW + fw) * FS_SPATIAL +
                       gc_out.cur_grp % pack_c_size;
            } else if (src.layout.ndim == 5) {
                // normal case
                return gc_out.cur_grp * FS_G +
                       gc_out.cur_off / pack_c_size * FS_OC +
                       (ic - ic0) / pack_c_size * FS_IC +
                       (fh * FW + fw) * FS_SPATIAL +
                       ((ic - ic0) % pack_c_size) * pack_c_size +
                       gc_out.cur_off % pack_c_size;
            } else {
                megdnn_throw(
                        "nchw88/nchw44 naive not support this input and "
                        "output\n");
            }
        } else if (filter_meta.format == Format::NCHW44_DOT) {
            if (src.layout.ndim == 4) {
                // ic < 4, input is nchw
                return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
                       (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
                       gc_out.cur_off % 4;
            } else if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
                       filter_meta.ocpg == 1 && src.layout.ndim == 5) {
                // dw case
                return gc_out.cur_grp / 4 * FS_G + gc_out.cur_off * FS_OC +
                       (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL +
                       gc_out.cur_grp % 4;
            } else if (src.layout.ndim == 5) {
                // normal case
                return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
                       (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
                       (gc_out.cur_off % 4) * 4 + ((ic - ic0) % 4);
            } else {
                megdnn_throw(
                        "nchw44_dot naive not support this input and output\n");
            }
        } else {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL;
        }
    };
    size_t filter_sizes = filter_meta.ocpg * filter_meta.icpg * FH * FW;
    for (size_t n = 0; n < N; ++n) {
        GroupCounter gc_out{filter_meta.ocpg};
        for (size_t oc = 0; oc < OC; ++oc, gc_out.next()) {
            // input channels [ic0, ic1) belong to the group of oc
            const size_t ic0 = gc_out.cur_grp * filter_meta.icpg,
                         ic1 = ic0 + filter_meta.icpg;
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow) {
                    const ptrdiff_t dst_addr =
                            get_linear_addr(n, oc, oh, ow, dst.layout, true);
                    comp_type dval = dptr[dst_addr];
                    // `template` must not precede a name that has no
                    // template argument list; deduction picks ftype.
                    ftype* fptr_cur = FilterVisitor::get_current_ptr(
                            fptr, n, oc, oh, ow, filter_sizes);
                    Strategy::init_dval(dval);

                    for (size_t fh = 0; fh < FH; ++fh)
                        for (size_t fw = 0; fw < FW; ++fw) {
                            size_t ih = sh * oh + fh * dh + h_offset,
                                   iw = sw * ow + fw * dw + w_offset;
                            // here ih and iw are represented in unsigned int
                            // they will become very large if underflow occurs
                            if (ih < IH && iw < IW) {
                                for (size_t ic = ic0; ic < ic1; ++ic) {
                                    stype& sval = sptr[get_linear_addr(
                                            n, ic, ih, iw, src.layout, false)];
                                    ftype& fval = fptr_cur[get_filter_addr(
                                            gc_out, ic, ic0, fh, fw)];
                                    Strategy::on(sval, fval, dval,
                                                 src.layout.dtype,
                                                 filter_meta.dtype,
                                                 dst.layout.dtype);
                                }
                            }
                        }
                    Strategy::write(dval, dptr[dst_addr]);
                }
        }
    }
}

/*!
 * \brief Naive 2D convolution for the NHWCD4 format.
 *
 * src and dst are addressed as (N, H, C/4, W, 4) through their strides.
 *
 * The filter's layout is (G, OC/4, FH, FW, IC, 4) when using mad and
 * (G, OC/4, FH, FW, IC/4, 4, 4) when using dot. Dense filters drop the
 * leading G, and channel-wise filters are (G/4, 1, FH, FW, 4).
 *
 * The dot layout is selected in either case:
 * - src is QuantizedS8 or Quantized8Asymm;
 * - src is QuantizedS32 and the filter is QuantizedS8 or Quantized8Asymm.
 *
 * The loop nest and the \p Strategy / \p FilterVisitor contracts mirror
 * compute2d. Filter flipping is rejected.
 */
template <typename stype, typename ftype, typename dtype, typename comp_type,
          class Strategy, typename FilterMeta,
          typename FilterVisitor = ConvFilterVisitor>
void compute2d_hwcd4(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                     _megdnn_tensor_out dst, const FilterMeta& filter_meta) {
    bool use_dot = false;
    if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
        src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
        (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
         (filter.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
          filter.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm)))
        use_dot = true;

    using Format = param::Convolution::Format;
    megdnn_assert(filter_meta.format == Format::NHWCD4);
    auto N = src.layout.shape[0], IH = src.layout.shape[1],
         IW = src.layout.shape[3];
    auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1];
    // dst.shape[2] counts blocks of 4 channels
    auto OC = dst.layout.shape[2] * 4, OH = dst.layout.shape[1],
         OW = dst.layout.shape[3];
    int ph = filter_meta.padding[0], pw = filter_meta.padding[1];
    size_t sh = filter_meta.stride[0], sw = filter_meta.stride[1];
    int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1];
    stype* __restrict sptr = src.compatible_ptr<stype>();
    ftype* __restrict fptr = filter.compatible_ptr<ftype>();
    dtype* __restrict dptr = dst.compatible_ptr<dtype>();

    megdnn_assert(!filter_meta.should_flip);
    int h_offset = -ph, w_offset = -pw;

    auto get_linear_addr = [](size_t n, size_t c, size_t h, size_t w,
                              const TensorLayout& layout) -> size_t {
        return n * layout.stride[0] + h * layout.stride[1] +
               (c / 4) * layout.stride[2] + w * layout.stride[3] +
               c % 4 * layout.stride[4];
    };

    // Element strides of the filter's group, output-channel-block and
    // spatial axes; an axis the layout does not have keeps stride 0.
    // Zero-initialised so no path reads an indeterminate value.
    size_t FS_G = 0, FS_OCB = 0, FS_SPATIAL = 0;
    if (!use_dot && filter.layout.ndim == 5) {
        if (filter_meta.ocpg == 1 && filter_meta.icpg == 1) {
            // chanwise conv, (G/4, 1, FH, FW, 4)
            FS_G = filter.layout.stride[0];
            FS_OCB = 0;
            FS_SPATIAL = 4;
        } else {
            // dense conv, (OC/4, FH, FW, IC, 4)
            FS_G = 0;
            FS_OCB = filter.layout.stride[0];
            FS_SPATIAL = filter.layout.stride[2];
        }
    } else if (!use_dot && filter.layout.ndim == 6) {
        // group conv, (G, OC/4, FH, FW, IC, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = filter.layout.stride[1];
        FS_SPATIAL = filter.layout.stride[3];
    } else if (use_dot && filter.layout.ndim == 6) {
        // dense conv used dot, (OC/4, FH, FW, IC/4, 4, 4)
        FS_G = 0;
        FS_OCB = filter.layout.stride[0];
        FS_SPATIAL = filter.layout.stride[2];
    } else if (use_dot && filter.layout.ndim == 7) {
        // group conv used dot, (G, OC/4, FH, FW, IC/4, 4, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = filter.layout.stride[1];
        FS_SPATIAL = filter.layout.stride[3];
    } else if (use_dot && filter.layout.ndim == 5 && filter_meta.ocpg == 1 &&
               filter_meta.icpg == 1) {
        // chanwise conv, (G/4, 1, FH, FW, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = 0;
        FS_SPATIAL = 4;
    } else {
        megdnn_assert(0, "invalid filter layout");
    }

    // Element offset of the weight for (group, offset-in-group) at tap
    // (fh, fw) and input channel c of the group.
    auto get_filter_addr = [&use_dot, &FS_G, &FS_OCB, &FS_SPATIAL, &FW,
                            &filter_meta](size_t group, size_t offset,
                                          size_t fh, size_t fw,
                                          size_t c) -> size_t {
        if (filter_meta.ocpg == 1 && filter_meta.icpg == 1) {
            return (group / 4) * FS_G + (fh * FW + fw) * FS_SPATIAL +
                   (group % 4);
        } else if (!use_dot) {
            return group * FS_G + (offset / 4) * FS_OCB +
                   (fh * FW + fw) * FS_SPATIAL + c * 4 + (offset % 4);
        } else {
            megdnn_assert(use_dot);
            return group * FS_G + (offset / 4) * FS_OCB +
                   (fh * FW + fw) * FS_SPATIAL + (c / 4) * 16 +
                   (offset % 4) * 4 + (c % 4);
        }
    };

    size_t filter_sizes = filter_meta.ocpg * filter_meta.icpg * FH * FW;
    for (size_t n = 0; n < N; ++n) {
        GroupCounter gc_out{filter_meta.ocpg};
        for (size_t oc = 0; oc < OC; ++oc, gc_out.next()) {
            // input channels [ic0, ic1) belong to the group of oc
            const size_t ic0 = gc_out.cur_grp * filter_meta.icpg,
                         ic1 = ic0 + filter_meta.icpg;
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow) {
                    const size_t dst_addr =
                            get_linear_addr(n, oc, oh, ow, dst.layout);
                    comp_type dval = dptr[dst_addr];
                    Strategy::init_dval(dval);
                    // `template` must not precede a name that has no
                    // template argument list; deduction picks ftype.
                    ftype* fptr_cur = FilterVisitor::get_current_ptr(
                            fptr, n, oc, oh, ow, filter_sizes);

                    for (size_t fh = 0; fh < FH; ++fh)
                        for (size_t fw = 0; fw < FW; ++fw) {
                            size_t ih = sh * oh + fh * dh + h_offset,
                                   iw = sw * ow + fw * dw + w_offset;
                            // here ih and iw are represented in unsigned int
                            // they will become very large if underflow occurs
                            if (ih < IH && iw < IW) {
                                for (size_t ic = ic0; ic < ic1; ++ic) {
                                    stype& sval = sptr[get_linear_addr(
                                            n, ic, ih, iw, src.layout)];
                                    ftype& fval = fptr_cur[get_filter_addr(
                                            gc_out.cur_grp, gc_out.cur_off, fh,
                                            fw, ic - ic0)];
                                    Strategy::on(sval, fval, dval,
                                                 src.layout.dtype,
                                                 filter_meta.dtype,
                                                 dst.layout.dtype);
                                }
                            }
                        }
                    Strategy::write(dval, dptr[dst_addr]);
                }
        }
    }
}

//! forward with only filter ptr
//!
//! Requires a 2D filter in one of the formats listed in the assertion below,
//! then runs compute2d with StrategyFwd. StrategyFwd::on only reads the
//! filter element, so the const_cast merely adapts \p fptr to compute2d's
//! non-const parameter.
template <typename stype, typename ftype, typename dtype, typename comp_type>
void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
             const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(filter_meta.spatial_ndim == 2);
    megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW ||
                  filter_meta.format == param::Convolution::Format::NHWC ||
                  filter_meta.format == param::Convolution::Format::NCHW88 ||
                  filter_meta.format == param::Convolution::Format::NCHW44 ||
                  filter_meta.format == param::Convolution::Format::NCHW44_DOT ||
                  filter_meta.format == param::Convolution::Format::NCHW4);
    ftype* const filter_ptr = const_cast<ftype*>(fptr);
    compute2d<stype, ftype, dtype, comp_type, StrategyFwd>(src, filter_ptr, dst,
                                                           filter_meta);
}

//! forward with full filter (for API compatibility)
//!
//! NHWCD4 needs the filter's full layout and goes through compute2d_hwcd4.
//! Every other format only needs the raw filter pointer and is forwarded to
//! the pointer-based overload.
template <typename stype, typename ftype, typename dtype, typename comp_type>
void forward(_megdnn_tensor_in src, _megdnn_tensor_in filter,
             _megdnn_tensor_out dst,
             const Convolution::CanonizedFilterMeta& filter_meta) {
    if (filter_meta.format != param::Convolution::Format::NHWCD4) {
        return forward<stype, ftype, dtype, comp_type>(
                src, filter.compatible_ptr<ftype>(), dst, filter_meta);
    }
    return compute2d_hwcd4<stype, ftype, dtype, comp_type, StrategyFwd>(
            src, filter, dst, filter_meta);
}

//! Gradient with respect to the convolution input.
//!
//! \p grad must be contiguous. It is zero-filled, then compute2d (or
//! compute2d_hwcd4 for NHWCD4) accumulates into it with StrategyBwdData,
//! using \p grad in the "src" role and \p diff in the "dst" role.
template <typename ftype, typename dtype, typename gtype>
void backward_data(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                   _megdnn_tensor_out grad,
                   const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(grad.layout.is_contiguous());
    const auto grad_bytes = grad.layout.span().dist_byte();
    memset(grad.raw_ptr, 0, grad_bytes);
    megdnn_assert(filter_meta.spatial_ndim == 2);
    if (filter_meta.format != param::Convolution::Format::NHWCD4) {
        compute2d<gtype, ftype, dtype, dtype, StrategyBwdData>(
                grad, filter.compatible_ptr<ftype>(), diff, filter_meta);
        return;
    }
    compute2d_hwcd4<gtype, ftype, dtype, dtype, StrategyBwdData>(
            grad, filter, diff, filter_meta);
}

//! Gradient with respect to the filter.
//!
//! \p grad must be contiguous. It is zero-filled, then compute2d accumulates
//! into it with StrategyBwdFlt, using \p grad's data as the filter pointer.
template <typename stype, typename dtype, typename gtype>
void backward_filter(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                     _megdnn_tensor_out grad,
                     const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(grad.layout.is_contiguous());
    const auto grad_bytes = grad.layout.span().dist_byte();
    memset(grad.raw_ptr, 0, grad_bytes);
    megdnn_assert(filter_meta.spatial_ndim == 2);
    gtype* grad_ptr = grad.compatible_ptr<gtype>();
    compute2d<stype, gtype, dtype, dtype, StrategyBwdFlt>(src, grad_ptr, diff,
                                                          filter_meta);
}

template <typename stype, typename ftype, typename dtype, typename comp_type,
          typename FilterMeta, typename FilterVisitor = ConvFilterVisitor>
void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                  _megdnn_tensor_in bias, _megdnn_tensor_out dst,
                  dt_byte* /* workspace_ptr */, const FilterMeta& filter_meta) {
    megdnn_assert(filter_meta.spatial_ndim == 2);
    switch (filter_meta.format) {
        case param::Convolution::Format::NCHW:
        case param::Convolution::Format::NCHW88:
631
        case param::Convolution::Format::NCHW44:
632
        case param::Convolution::Format::NCHW44_DOT:
633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708
        case param::Convolution::Format::NHWC:
        case param::Convolution::Format::NCHW4:
        case param::Convolution::Format::NCHW8:
        case param::Convolution::Format::NCHW32:
        case param::Convolution::Format::CHWN4:
            compute2d<stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta,
                      FilterVisitor>(src, filter.compatible_ptr<ftype>(), dst,
                                     filter_meta);
            break;
        case param::Convolution::Format::NHWCD4:
            compute2d_hwcd4<stype, ftype, dtype, comp_type, StrategyFwd,
                            FilterMeta, FilterVisitor>(src, filter, dst,
                                                       filter_meta);
            break;
        default:
            megdnn_assert_internal(0);
    }

    //! we can not decide with bias.raw_ptr, as non bias the raw_ptr is not
    //! nullptr
    if (bias.layout.ndim != 0) {
        if (dst.layout.eq_shape(bias.layout) &&
            dst.layout.dtype.enumv() == bias.layout.dtype.enumv()) {
            dtype* dst_ptr = dst.compatible_ptr<dtype>();
            dtype* bias_ptr = bias.compatible_ptr<dtype>();
            for (size_t i = 0; i < dst.layout.span().dist_elem(); i++) {
                comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                static_cast<comp_type>(bias_ptr[0]);
                dst_ptr[0] = val;
                dst_ptr++;
                bias_ptr++;
            }
            return;
        }

        using Format = param::ConvBias::Format;
        switch (filter_meta.format) {
            case Format::NCHW: {
                int dst_batch = dst.layout.shape[0];
                int dst_channel = dst.layout.shape[1];
                int chann_stride = dst.layout.shape[2] * dst.layout.shape[3];
                dtype* dst_ptr = dst.compatible_ptr<dtype>();

                for (int batch = 0; batch < dst_batch; ++batch) {
                    for (int chan = 0; chan < dst_channel; ++chan) {
                        dtype bias_val = bias.compatible_ptr<dtype>()[chan];
                        for (int i = 0; i < chann_stride; ++i, ++dst_ptr) {
                            comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                            static_cast<comp_type>(bias_val);
                            dst_ptr[0] = val;
                        }
                    }
                }
                break;
            };
//! Add a per-channel bias to a contiguous packed-channel tensor `dst`.
//! The indexing assumes `dst` is laid out as (N, C / x, H, W, x) with
//! x == _pack_size, and reads `bias` as a flat array of
//! C = shape[1] * _pack_size values (one per output channel).
//! Expects `dst`, `bias`, `dtype` and `comp_type` to be in scope at the
//! expansion site. All index math is done in size_t: the shapes are size_t,
//! and the flattened offset can exceed INT_MAX for large tensors, where int
//! arithmetic would overflow.
#define BIAS_ADD_NCHWx(_pack_size)                                          \
    do {                                                                    \
        megdnn_assert(dst.layout.is_contiguous());                          \
        const size_t pack = static_cast<size_t>(_pack_size);                \
        const size_t dst_batch = dst.layout.shape[0];                       \
        const size_t dst_channel = dst.layout.shape[1] * pack;              \
        const size_t chann_stride =                                         \
                dst.layout.shape[2] * dst.layout.shape[3];                  \
        dtype* dst_ptr = dst.compatible_ptr<dtype>();                       \
        const dtype* bias_ptr = bias.compatible_ptr<dtype>();               \
        for (size_t batch = 0; batch < dst_batch; ++batch) {                \
            for (size_t chan = 0; chan < dst_channel; ++chan) {             \
                const dtype bias_val = bias_ptr[chan];                      \
                /* offset of element (batch, chan, spatial 0) */            \
                const size_t base = batch * dst_channel * chann_stride +    \
                                    (chan / pack) * (chann_stride * pack) + \
                                    chan % pack;                            \
                for (size_t i = 0; i < chann_stride; ++i) {                 \
                    dtype& elem = dst_ptr[base + i * pack];                 \
                    elem = static_cast<comp_type>(elem) +                   \
                           static_cast<comp_type>(bias_val);                \
                }                                                           \
            }                                                               \
        }                                                                   \
    } while (0)
709
            case Format::NCHW44:
710
            case Format::NCHW44_DOT:
711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802
            case Format::NCHW4: {
                BIAS_ADD_NCHWx(4);
                break;
            };
            case Format::NCHW8: {
                BIAS_ADD_NCHWx(8);
                break;
            };
            case Format::NCHW32: {
                BIAS_ADD_NCHWx(32);
                break;
            };
            case Format::NCHW88: {
                BIAS_ADD_NCHWx(8);
                break;
            };
//! Add a per-channel bias to a contiguous channel-major packed tensor `dst`.
//! The indexing assumes `dst` is laid out as (C / x, H, W, N, x) with
//! x == _pack_size, and reads `bias` as a flat array of
//! C = shape[0] * _pack_size values (one per output channel).
//! Expects `dst`, `bias`, `dtype` and `comp_type` to be in scope at the
//! expansion site. Index math uses size_t so that large tensors cannot
//! overflow the flattened offset, as int arithmetic could.
#define BIAS_ADD_CHWNx(_pack_size)                                            \
    do {                                                                      \
        megdnn_assert(dst.layout.is_contiguous());                            \
        const size_t pack = static_cast<size_t>(_pack_size);                  \
        const size_t dst_batch = dst.layout.shape[3];                         \
        const size_t dst_channel = dst.layout.shape[0] * pack;                \
        const size_t chann_stride =                                           \
                dst.layout.shape[1] * dst.layout.shape[2] * dst_batch;        \
        dtype* dst_ptr = dst.compatible_ptr<dtype>();                         \
        const dtype* bias_ptr = bias.compatible_ptr<dtype>();                 \
        for (size_t chan = 0; chan < dst_channel; ++chan) {                   \
            const dtype bias_val = bias_ptr[chan];                            \
            /* offset of the first (h, w, n) element of this channel */       \
            const size_t base =                                               \
                    (chan / pack) * chann_stride * pack + chan % pack;        \
            for (size_t i = 0; i < chann_stride; ++i) {                       \
                dtype& elem = dst_ptr[base + i * pack];                       \
                elem = static_cast<comp_type>(elem) +                         \
                       static_cast<comp_type>(bias_val);                      \
            }                                                                 \
        }                                                                     \
    } while (0)
            case Format::CHWN4: {
                BIAS_ADD_CHWNx(4);
                break;
            }
            case Format::NHWC: {
                int dst_nhw = dst.layout.shape[0] * dst.layout.shape[1] *
                              dst.layout.shape[2];
                int dst_channel = dst.layout.shape[3];
                dtype* dst_ptr = dst.compatible_ptr<dtype>();

                for (int nhw = 0; nhw < dst_nhw; ++nhw) {
                    for (int chan = 0; chan < dst_channel; ++chan, ++dst_ptr) {
                        dtype bias_val = bias.compatible_ptr<dtype>()[chan];
                        comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                        static_cast<comp_type>(bias_val);
                        dst_ptr[0] = val;
                    }
                }
                break;
            };
            case Format::NHWCD4: {
                dtype* bias_ptr = bias.compatible_ptr<dtype>();
                dtype* dst_ptr = dst.compatible_ptr<dtype>();
                for (size_t n = 0; n < dst.layout[0]; n++) {
                    for (size_t h = 0; h < dst.layout[1]; h++) {
                        for (size_t cb = 0; cb < dst.layout[2]; cb++) {
                            for (size_t w = 0; w < dst.layout[3]; w++) {
                                for (size_t i = 0; i < 4; i++) {
                                    auto ptr = dst_ptr +
                                               n * dst.layout.stride[0] +
                                               h * dst.layout.stride[1] +
                                               cb * dst.layout.stride[2] +
                                               w * dst.layout.stride[3] +
                                               i * dst.layout.stride[4];
                                    comp_type val =
                                            static_cast<comp_type>(ptr[0]) +
                                            static_cast<comp_type>(
                                                    bias_ptr[cb * 4 + i]);
                                    ptr[0] = val;
                                }
                            }
                        }
                    }
                }
                break;
            };
            default:
                megdnn_assert_internal(0);
        }
    }
}

}  // namespace convolution
}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen