/**
 * \file dnn/src/naive/convolution/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

#include <cstring>

namespace megdnn {
namespace naive {
namespace convolution {

//! Tracks, while channels are visited one after another, which group the
//! current channel belongs to (cur_grp) and its index inside that group
//! (cur_off). Every grp_size calls to next() move on to the next group.
struct GroupCounter {
    const size_t grp_size;
    size_t cur_grp = 0, cur_off = 0;

    explicit GroupCounter(size_t group_size) : grp_size{group_size} {}

    //! step to the following channel, rolling over into the next group once
    //! the offset reaches grp_size
    void next() {
        ++cur_off;
        if (cur_off != grp_size)
            return;
        cur_off = 0;
        ++cur_grp;
    }
};

//! Forward-convolution policy for compute2d: the per-output accumulator `d`
//! is zeroed, collects src * filter products, and is finally stored to dst.
struct StrategyFwd {
    //! zero the accumulator before an output element is computed
    template <typename AccT>
    static void init_dval(AccT& acc) {
        acc = static_cast<AccT>(0);
    }

    //! multiply-accumulate one (src, filter) pair in the accumulator type
    template <typename SrcT, typename FltT, typename AccT>
    static void on(SrcT& s, FltT& f, AccT& d, DType, DType, DType) {
        d += static_cast<AccT>(s) * static_cast<AccT>(f);
    }

    //! store the finished accumulator into the destination element
    template <typename AccT, typename DstT>
    static void write(AccT& acc, DstT& out) {
        out = static_cast<DstT>(acc);
    }
};

// Explicit specialization of a member function template is not allowed in
// class scope; this is a defect in the C++ specification that is fixed in
// C++17. We work around it by marking the implementation as inline and moving
// it out of the class definition.
template <>
M
Megvii Engine Team 已提交
59 60
inline void StrategyFwd::on(
        dt_quint8& s, dt_quint8& f, dt_qint32& d, DType src_dt, DType filt_dt, DType) {
61
    auto cast = [](const dt_quint8& val, DType dt) {
M
Megvii Engine Team 已提交
62 63 64
        return dt_qint32(
                static_cast<int32_t>(val.as_uint8()) -
                dt.param<dtype::Quantized8Asymm>().zero_point);
65 66 67 68
    };
    d += cast(s, src_dt) * cast(f, filt_dt);
}

69
template <>
M
Megvii Engine Team 已提交
70 71
inline void StrategyFwd::on(
        dt_qint8& s, dt_qint8& f, dt_float32& d, DType src_dt, DType filt_dt, DType) {
72 73 74 75 76 77
    auto cast = [](const dt_qint8& val, DType dt) {
        return dt.param<dtype::QuantizedS8>().dequantize(val);
    };
    d += cast(s, src_dt) * cast(f, filt_dt);
}

78
template <>
M
Megvii Engine Team 已提交
79 80
inline void StrategyFwd::on(
        dt_qint8& s, dt_qint8& f, dt_qint32& d, DType, DType, DType) {
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
    auto cast = [](const dt_qint8& val) {
        return dt_qint32(static_cast<int32_t>(val.as_int8()));
    };
    d += cast(s) * cast(f);
}

//! Backward-data policy for compute2d: the element reached through the
//! tensor passed as `src` accumulates filter * dst-side value. init_dval and
//! write are no-ops, so compute2d's dval keeps the value loaded from dst and
//! dst is never written. (Presumably src is the input gradient and dst the
//! output diff — confirm in backward_data.)
struct StrategyBwdData {
    template <typename SrcT, typename FltT, typename DstT>
    static void on(SrcT& s, FltT& f, DstT& d, DType, DType, DType) {
        s += static_cast<SrcT>(f) * static_cast<SrcT>(d);
    }

    //! intentionally does nothing: the dst-side tensor is read-only here
    template <typename AccT, typename DstT>
    static void write(AccT&, DstT&) {}

    //! intentionally does nothing: keep the value loaded from dst
    template <typename AccT>
    static void init_dval(AccT&) {}
};

template <>
M
Megvii Engine Team 已提交
101 102
inline void StrategyBwdData::on(
        int& s, signed char& f, signed char& d, DType, DType, DType) {
103 104 105 106 107 108 109
    auto cast = [](signed char& val) {
        return static_cast<int32_t>(((megdnn::dt_qint8)val).as_int8());
    };
    s += cast(f) * cast(d);
}

template <>
M
Megvii Engine Team 已提交
110 111
inline void StrategyBwdData::on(
        dt_qint32& s, dt_quint8& f, dt_quint8& d, DType, DType filt_dt, DType dst_dt) {
112
    auto cast = [](const dt_quint8& val, DType dt) {
M
Megvii Engine Team 已提交
113 114 115
        return dt_qint32(
                static_cast<int32_t>(val.as_uint8()) -
                dt.param<dtype::Quantized8Asymm>().zero_point);
116 117 118 119 120
    };
    s += cast(f, filt_dt) * cast(d, dst_dt);
}

template <>
M
Megvii Engine Team 已提交
121 122
inline void StrategyBwdData::on(
        dt_qint32& s, dt_qint8& f, dt_qint8& d, DType, DType, DType) {
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
    auto cast = [](const dt_qint8& val) {
        return dt_qint32(static_cast<int32_t>(val.as_int8()));
    };
    s += cast(f) * cast(d);
}

//! Backward-filter policy for compute2d: the filter element accumulates
//! src * dst-side value. init_dval and write are no-ops, so compute2d's dval
//! keeps the value loaded from dst and dst is never written.
struct StrategyBwdFlt {
    template <typename SrcT, typename FltT, typename DstT>
    static void on(SrcT& s, FltT& f, DstT& d, DType, DType, DType) {
        f += static_cast<FltT>(s) * static_cast<FltT>(d);
    }

    //! intentionally does nothing: the dst-side tensor is read-only here
    template <typename AccT, typename DstT>
    static void write(AccT&, DstT&) {}

    //! intentionally does nothing: keep the value loaded from dst
    template <typename AccT>
    static void init_dval(AccT&) {}
};

//! Default filter visitor for compute2d: the same filter base pointer is used
//! for every output position, so all position arguments are ignored.
struct ConvFilterVisitor {
    template <typename ftype>
    static ftype* get_current_ptr(
            ftype* base, size_t /* batch */, size_t /* oc */, size_t /* oh */,
            size_t /* ow */, size_t /* filter_sizes */) {
        return base;
    }
};

M
Megvii Engine Team 已提交
151 152 153 154 155 156
template <
        typename stype, typename ftype, typename dtype, typename comp_type,
        class Strategy, typename FilterMeta, typename FilterVisitor = ConvFilterVisitor>
void compute2d(
        _megdnn_tensor_in src, ftype* __restrict fptr, _megdnn_tensor_out dst,
        const FilterMeta& filter_meta) {
157 158
    size_t spatial_start, channel_pos, batch_pos;
    using Format = param::Convolution::Format;
M
Megvii Engine Team 已提交
159
    if (filter_meta.format == Format::NCHW || filter_meta.format == Format::NCHW88 ||
160
        filter_meta.format == Format::NCHW44 ||
161
        filter_meta.format == Format::NCHW44_DOT ||
162
        filter_meta.format == Format::NCHW4 ||
163
        filter_meta.format == Format::NCHW4_NCHW ||
M
Megvii Engine Team 已提交
164
        filter_meta.format == Format::NCHW4_NHWC ||
165
        filter_meta.format == Format::NCHW4_NCHW32 ||
M
Megvii Engine Team 已提交
166
        filter_meta.format == Format::NCHW8 || filter_meta.format == Format::NCHW32 ||
167 168
        filter_meta.format == Format::NCHW32_NCHW4 ||
        filter_meta.format == Format::NCHW64) {
169 170 171 172 173 174 175 176
        spatial_start = 2;
        channel_pos = 1;
        batch_pos = 0;
    } else if (filter_meta.format == Format::CHWN4) {
        spatial_start = 1;
        channel_pos = 0;
        batch_pos = 3;
    } else {
M
Megvii Engine Team 已提交
177
        megdnn_assert(filter_meta.format == Format::NHWC, "invalid conv format");
178 179 180 181 182 183 184 185
        spatial_start = 1;
        channel_pos = 3;
        batch_pos = 0;
    }

    auto N = src.layout.shape[batch_pos], IH = src.layout.shape[spatial_start],
         IW = src.layout.shape[spatial_start + 1];
    auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1];
186 187
    size_t OC, OH, OW;
    if (filter_meta.format == Format::NCHW4_NHWC) {
M
Megvii Engine Team 已提交
188
        OC = dst.layout.shape[3], OH = dst.layout.shape[1], OW = dst.layout.shape[2];
189
    } else {
M
Megvii Engine Team 已提交
190
        OC = dst.layout.shape[channel_pos], OH = dst.layout.shape[spatial_start],
191 192
        OW = dst.layout.shape[spatial_start + 1];
    }
193

M
Megvii Engine Team 已提交
194
    if (filter_meta.format == Format::NCHW4 || filter_meta.format == Format::CHWN4 ||
195
        filter_meta.format == Format::NCHW44_DOT ||
M
Megvii Engine Team 已提交
196
        filter_meta.format == Format::NCHW44 ||
197
        filter_meta.format == Format::NCHW32_NCHW4) {
198
        OC *= 4;
M
Megvii Engine Team 已提交
199 200 201
    } else if (
            filter_meta.format == Format::NCHW8 ||
            filter_meta.format == Format::NCHW88) {
202
        OC *= 8;
M
Megvii Engine Team 已提交
203 204 205
    } else if (
            filter_meta.format == Format::NCHW32 ||
            filter_meta.format == Format::NCHW4_NCHW32) {
206
        OC *= 32;
207 208
    } else if (filter_meta.format == Format::NCHW64) {
        OC *= 64;
209 210 211
    }

    size_t FS_G, FS_OC, FS_IC, FS_SPATIAL;
M
Megvii Engine Team 已提交
212
    if (filter_meta.format == Format::NCHW || filter_meta.format == Format::NCHW4 ||
213
        filter_meta.format == Format::NCHW4_NCHW ||
M
Megvii Engine Team 已提交
214
        filter_meta.format == Format::NCHW4_NHWC ||
215
        filter_meta.format == Format::NCHW4_NCHW32 ||
M
Megvii Engine Team 已提交
216
        filter_meta.format == Format::NCHW8 || filter_meta.format == Format::NCHW32 ||
217 218
        filter_meta.format == Format::NCHW32_NCHW4 ||
        filter_meta.format == Format::NCHW64) {
219 220 221 222 223 224 225 226 227 228 229 230
        // g, oc, ic, fh, fw
        FS_SPATIAL = 1;
        FS_IC = FH * FW;
        FS_OC = FS_IC * filter_meta.icpg;
        FS_G = FS_OC * filter_meta.ocpg;
    } else if (filter_meta.format == Format::CHWN4) {
        // g, ic, fh, fw, oc, pack_size
        FS_SPATIAL = filter_meta.ocpg * 4;
        FS_IC = FH * FW * FS_SPATIAL;
        FS_OC = 4;
        FS_G = FS_IC * filter_meta.icpg;
    } else if (filter_meta.format == Format::NCHW88) {
M
Megvii Engine Team 已提交
231 232
        if (filter_meta.group > 1 && filter_meta.icpg == 1 && src.layout.ndim == 5 &&
            filter_meta.ocpg == 1) {
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249
            FS_SPATIAL = 8;
            FS_IC = FH * FW * FS_SPATIAL;
            FS_OC = FS_IC * filter_meta.icpg;
            FS_G = FS_OC * filter_meta.ocpg;
        } else {
            if (src.layout.ndim == 4 && dst.layout.ndim == 5) {
                FS_IC = 8;
                FS_SPATIAL = filter_meta.icpg * FS_IC;
                FS_OC = FH * FW * FS_SPATIAL;
                FS_G = FS_OC * filter_meta.ocpg / 8;
            } else {
                FS_SPATIAL = 8 * 8;
                FS_IC = FH * FW * FS_SPATIAL;
                FS_OC = FS_IC * filter_meta.icpg / 8;
                FS_G = FS_OC * filter_meta.ocpg / 8;
            }
        }
M
Megvii Engine Team 已提交
250 251 252 253 254
    } else if (
            filter_meta.format == Format::NCHW44 ||
            filter_meta.format == Format::NCHW44_DOT) {
        if (filter_meta.group > 1 && filter_meta.icpg == 1 && src.layout.ndim == 5 &&
            filter_meta.ocpg == 1) {
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
            FS_SPATIAL = 4;
            FS_IC = FH * FW * FS_SPATIAL;
            FS_OC = FS_IC * filter_meta.icpg;
            FS_G = FS_OC * filter_meta.ocpg;
        } else {
            if (src.layout.ndim == 4 && dst.layout.ndim == 5) {
                FS_IC = 4;
                FS_SPATIAL = filter_meta.icpg * FS_IC;
                FS_OC = FH * FW * FS_SPATIAL;
                FS_G = FS_OC * filter_meta.ocpg / 4;
            } else {
                FS_SPATIAL = 4 * 4;
                FS_IC = FH * FW * FS_SPATIAL;
                FS_OC = FS_IC * filter_meta.icpg / 4;
                FS_G = FS_OC * filter_meta.ocpg / 4;
            }
        }
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
    } else {
        // g, oc, fh, fw, ic
        megdnn_assert(filter_meta.format == Format::NHWC);
        FS_IC = 1;
        FS_SPATIAL = filter_meta.icpg;
        FS_OC = FS_SPATIAL * FH * FW;
        FS_G = FS_OC * filter_meta.ocpg;
    }
    int ph = filter_meta.padding[0], pw = filter_meta.padding[1];
    size_t sh = filter_meta.stride[0], sw = filter_meta.stride[1];
    int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1];
    stype* __restrict sptr = src.compatible_ptr<stype>();
    dtype* __restrict dptr = dst.compatible_ptr<dtype>();

    int h_offset = -ph, w_offset = -pw;
    if (filter_meta.should_flip) {
        h_offset += filter_meta.dilated_spatial[0] - 1;
        w_offset += filter_meta.dilated_spatial[1] - 1;
        dh = -dh;
        dw = -dw;
    }

M
Megvii Engine Team 已提交
294 295 296 297
    auto get_linear_addr = [&filter_meta, &src](
                                   ptrdiff_t n, ptrdiff_t c, ptrdiff_t h, ptrdiff_t w,
                                   const TensorLayout& layout,
                                   bool is_output) -> ptrdiff_t {
298
        if (filter_meta.format == Format::NCHW) {
M
Megvii Engine Team 已提交
299 300
            return n * layout.stride[0] + c * layout.stride[1] + h * layout.stride[2] +
                   w * layout.stride[3];
301
        } else if (filter_meta.format == Format::NHWC) {
M
Megvii Engine Team 已提交
302 303 304 305 306
            return n * layout.stride[0] + h * layout.stride[1] + w * layout.stride[2] +
                   c * layout.stride[3];
        } else if (
                filter_meta.format == Format::NCHW8 ||
                filter_meta.format == Format::NCHW88) {
307 308 309 310 311 312 313 314 315
            if (filter_meta.format == Format::NCHW88 && !is_output &&
                src.layout.ndim == 4) {
                return n * layout.stride[0] + c * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 8) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b111) * layout.stride[4];
            }
M
Megvii Engine Team 已提交
316 317 318
        } else if (
                filter_meta.format == Format::NCHW44 ||
                filter_meta.format == Format::NCHW44_DOT) {
319
            if (!is_output && src.layout.ndim == 4) {
320 321 322 323 324 325 326
                return n * layout.stride[0] + c * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c % 4) * layout.stride[4];
            }
327 328 329 330
        } else if (filter_meta.format == Format::NCHW32) {
            return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3] +
                   (c & 0x1F) * layout.stride[4];
331 332 333 334 335 336 337 338 339 340
        } else if (filter_meta.format == Format::NCHW32_NCHW4) {
            if (is_output) {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b11) * layout.stride[4];
            } else {
                return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0x1F) * layout.stride[4];
            }
341 342 343 344
        } else if (filter_meta.format == Format::CHWN4) {
            return (c / 4) * layout.stride[0] + h * layout.stride[1] +
                   w * layout.stride[2] + n * layout.stride[3] +
                   (c % 4) * layout.stride[4];
345 346 347 348 349 350 351 352 353
        } else if (filter_meta.format == Format::NCHW4_NCHW) {
            if (is_output) {
                return n * layout.stride[0] + c * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b11) * layout.stride[4];
            }
354 355 356 357 358 359 360 361 362
        } else if (filter_meta.format == Format::NCHW4_NHWC) {
            if (is_output) {
                return n * layout.stride[0] + h * layout.stride[1] +
                       w * layout.stride[2] + c * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b11) * layout.stride[4];
            }
363 364 365 366 367 368 369 370 371 372
        } else if (filter_meta.format == Format::NCHW4_NCHW32) {
            if (is_output) {
                return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0x1F) * layout.stride[4];
            } else {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b11) * layout.stride[4];
            }
373 374 375 376
        } else if (filter_meta.format == Format::NCHW64) {
            return n * layout.stride[0] + (c >> 6) * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3] +
                   (c & 0x3F) * layout.stride[4];
377
        } else {
M
Megvii Engine Team 已提交
378
            megdnn_assert(filter_meta.format == Format::NCHW4, "invalid conv format");
379 380 381 382 383 384
            return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3] +
                   (c & 0b11) * layout.stride[4];
        }
    };

M
Megvii Engine Team 已提交
385 386
    auto get_filter_addr = [&](GroupCounter& gc_out, size_t ic, size_t ic0, size_t fh,
                               size_t fw) {
387 388
        if (filter_meta.format == Format::NCHW4 ||
            filter_meta.format == Format::NCHW4_NCHW ||
389
            filter_meta.format == Format::NCHW4_NHWC ||
390
            filter_meta.format == Format::NCHW4_NCHW32) {
391
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
M
Megvii Engine Team 已提交
392 393
                   (ic - ic0) / 4 * FS_IC * 4 + (fh * FW + fw) * FS_SPATIAL * 4 +
                   ((ic - ic0) & 0b11);
394 395
        } else if (filter_meta.format == Format::NCHW8) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
M
Megvii Engine Team 已提交
396 397 398 399 400
                   (ic - ic0) / 8 * FS_IC * 8 + (fh * FW + fw) * FS_SPATIAL * 8 +
                   ((ic - ic0) & 0b111);
        } else if (
                filter_meta.format == Format::NCHW32 ||
                filter_meta.format == Format::NCHW32_NCHW4) {
401
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
M
Megvii Engine Team 已提交
402 403
                   (ic - ic0) / 32 * FS_IC * 32 + (fh * FW + fw) * FS_SPATIAL * 32 +
                   ((ic - ic0) & 0x1F);
404 405 406 407
        } else if (filter_meta.format == Format::CHWN4) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
                   ((ic - ic0) % 4);
M
Megvii Engine Team 已提交
408 409 410
        } else if (
                filter_meta.format == Format::NCHW88 ||
                filter_meta.format == Format::NCHW44) {
411
            size_t pack_c_size = 4_z;
M
Megvii Engine Team 已提交
412
            if (filter_meta.format == Format::NCHW88) {
413 414
                pack_c_size = 8_z;
            }
415 416
            if (src.layout.ndim == 4) {
                // ic < 8, input is nchw
M
Megvii Engine Team 已提交
417
                return gc_out.cur_grp * FS_G + gc_out.cur_off / pack_c_size * FS_OC +
418
                       (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
419
                       gc_out.cur_off % pack_c_size;
M
Megvii Engine Team 已提交
420 421 422
            } else if (
                    filter_meta.group > 1 && filter_meta.icpg == 1 &&
                    filter_meta.ocpg == 1 && src.layout.ndim == 5) {
423
                // dw case
M
Megvii Engine Team 已提交
424 425
                return gc_out.cur_grp / pack_c_size * FS_G + gc_out.cur_off * FS_OC +
                       (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL +
426
                       gc_out.cur_grp % pack_c_size;
427 428
            } else if (src.layout.ndim == 5) {
                // normal case
M
Megvii Engine Team 已提交
429 430
                return gc_out.cur_grp * FS_G + gc_out.cur_off / pack_c_size * FS_OC +
                       (ic - ic0) / pack_c_size * FS_IC + (fh * FW + fw) * FS_SPATIAL +
431 432
                       ((ic - ic0) % pack_c_size) * pack_c_size +
                       gc_out.cur_off % pack_c_size;
433
            } else {
434 435 436
                megdnn_throw(
                        "nchw88/nchw44 naive not support this input and "
                        "output\n");
437
            }
438
        } else if (filter_meta.format == Format::NCHW44_DOT) {
439
            if (src.layout.ndim == 4) {
440
                // ic < 4, input is nchw
441 442 443
                return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
                       (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
                       gc_out.cur_off % 4;
M
Megvii Engine Team 已提交
444 445 446
            } else if (
                    filter_meta.group > 1 && filter_meta.icpg == 1 &&
                    filter_meta.ocpg == 1 && src.layout.ndim == 5) {
447 448 449 450 451 452 453 454
                // dw case
                return gc_out.cur_grp / 4 * FS_G + gc_out.cur_off * FS_OC +
                       (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL +
                       gc_out.cur_grp % 4;
            } else if (src.layout.ndim == 5) {
                // normal case
                return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
                       (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
455
                       (gc_out.cur_off % 4) * 4 + ((ic - ic0) % 4);
456
            } else {
M
Megvii Engine Team 已提交
457
                megdnn_throw("nchw44_dot naive not support this input and output\n");
458
            }
459 460
        } else if (filter_meta.format == Format::NCHW64) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
M
Megvii Engine Team 已提交
461 462
                   (ic - ic0) / 64 * FS_IC * 64 + (fh * FW + fw) * FS_SPATIAL * 64 +
                   ((ic - ic0) & 0x3F);
463
        } else {
M
Megvii Engine Team 已提交
464 465
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + (ic - ic0) * FS_IC +
                   (fh * FW + fw) * FS_SPATIAL;
466 467 468 469 470 471 472 473
        }
    };
    size_t filter_sizes = filter_meta.ocpg * filter_meta.icpg * FH * FW;
    for (size_t n = 0; n < N; ++n) {
        GroupCounter gc_out{filter_meta.ocpg};
        for (size_t oc = 0; oc < OC; ++oc, gc_out.next())
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow) {
M
Megvii Engine Team 已提交
474 475
                    comp_type dval =
                            dptr[get_linear_addr(n, oc, oh, ow, dst.layout, true)];
476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493
                    ftype* fptr_cur = FilterVisitor::template get_current_ptr(
                            fptr, n, oc, oh, ow, filter_sizes);
                    Strategy::init_dval(dval);

                    for (size_t fh = 0; fh < FH; ++fh)
                        for (size_t fw = 0; fw < FW; ++fw) {
                            size_t ih = sh * oh + fh * dh + h_offset,
                                   iw = sw * ow + fw * dw + w_offset;
                            // here ih and iw are represented in unsigned int
                            // they will become very large if underflow occurs
                            if (ih < IH && iw < IW) {
                                size_t ic0 = gc_out.cur_grp * filter_meta.icpg,
                                       ic1 = ic0 + filter_meta.icpg;
                                for (size_t ic = ic0; ic < ic1; ++ic) {
                                    stype& sval = sptr[get_linear_addr(
                                            n, ic, ih, iw, src.layout, false)];
                                    ftype& fval = fptr_cur[get_filter_addr(
                                            gc_out, ic, ic0, fh, fw)];
M
Megvii Engine Team 已提交
494 495 496
                                    Strategy::on(
                                            sval, fval, dval, src.layout.dtype,
                                            filter_meta.dtype, dst.layout.dtype);
497 498 499
                                }
                            }
                        }
M
Megvii Engine Team 已提交
500 501 502
                    Strategy::write(
                            dval,
                            dptr[get_linear_addr(n, oc, oh, ow, dst.layout, true)]);
503 504 505 506
                }
    }
}

M
Megvii Engine Team 已提交
507 508 509 510 511 512
template <
        typename stype, typename ftype, typename dtype, typename comp_type,
        class Strategy, typename FilterMeta, typename FilterVisitor = ConvFilterVisitor>
void compute2d_hwcd4(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const FilterMeta& filter_meta) {
513 514 515 516 517 518 519 520 521 522 523 524
    // The filter's layout is (G, OC/4, FH, FW, IC, 4) when using mad
    // and (G, OC/4, FH, FW, IC/4, 4, 4) when using dot.
    bool use_dot = false;
    if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
        src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
        (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
         (filter.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
          filter.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm)))
        use_dot = true;

    using Format = param::Convolution::Format;
    megdnn_assert(filter_meta.format == Format::NHWCD4);
M
Megvii Engine Team 已提交
525
    auto N = src.layout.shape[0], IH = src.layout.shape[1], IW = src.layout.shape[3];
526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573
    auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1];
    auto OC = dst.layout.shape[2] * 4, OH = dst.layout.shape[1],
         OW = dst.layout.shape[3];
    int ph = filter_meta.padding[0], pw = filter_meta.padding[1];
    size_t sh = filter_meta.stride[0], sw = filter_meta.stride[1];
    int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1];
    stype* __restrict sptr = src.compatible_ptr<stype>();
    ftype* __restrict fptr = filter.compatible_ptr<ftype>();
    dtype* __restrict dptr = dst.compatible_ptr<dtype>();

    megdnn_assert(!filter_meta.should_flip);
    int h_offset = -ph, w_offset = -pw;

    auto get_linear_addr = [](size_t n, size_t c, size_t h, size_t w,
                              const TensorLayout& layout) -> size_t {
        return n * layout.stride[0] + h * layout.stride[1] +
               (c / 4) * layout.stride[2] + w * layout.stride[3] +
               c % 4 * layout.stride[4];
    };

    size_t FS_G, FS_OCB, FS_SPATIAL;
    if (!use_dot && filter.layout.ndim == 5) {
        if (filter_meta.ocpg == 1 && filter_meta.icpg == 1) {
            // chanwise conv, (G/4, 1, FH, FW, 4)
            FS_G = filter.layout.stride[0];
            FS_OCB = 0;
            FS_SPATIAL = 4;
        } else {
            // dense conv, (OC/4, FH, FW, IC, 4)
            FS_G = 0;
            FS_OCB = filter.layout.stride[0];
            FS_SPATIAL = filter.layout.stride[2];
        }
    } else if (!use_dot && filter.layout.ndim == 6) {
        // group conv, (G, OC/4, FH, FW, IC, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = filter.layout.stride[1];
        FS_SPATIAL = filter.layout.stride[3];
    } else if (use_dot && filter.layout.ndim == 6) {
        // dense conv used dot, (OC/4, FH, FW, IC/4, 4, 4)
        FS_G = 0;
        FS_OCB = filter.layout.stride[0];
        FS_SPATIAL = filter.layout.stride[2];
    } else if (use_dot && filter.layout.ndim == 7) {
        // group conv used dot, (G, OC/4, FH, FW, IC/4, 4, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = filter.layout.stride[1];
        FS_SPATIAL = filter.layout.stride[3];
M
Megvii Engine Team 已提交
574 575 576
    } else if (
            use_dot && filter.layout.ndim == 5 && filter_meta.ocpg == 1 &&
            filter_meta.icpg == 1) {
577 578 579 580 581 582 583 584
        // chanwise conv, (G/4, 1, FH, FW, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = 0;
        FS_SPATIAL = 4;
    } else {
        megdnn_assert(0, "invalid filter layout");
    }

M
Megvii Engine Team 已提交
585 586 587
    auto get_filter_addr = [&use_dot, &FS_G, &FS_OCB, &FS_SPATIAL, &FW, &filter_meta](
                                   size_t group, size_t offset, size_t fh, size_t fw,
                                   size_t c) -> size_t {
588
        if (filter_meta.ocpg == 1 && filter_meta.icpg == 1) {
M
Megvii Engine Team 已提交
589
            return (group / 4) * FS_G + (fh * FW + fw) * FS_SPATIAL + (group % 4);
590
        } else if (!use_dot) {
M
Megvii Engine Team 已提交
591 592
            return group * FS_G + (offset / 4) * FS_OCB + (fh * FW + fw) * FS_SPATIAL +
                   c * 4 + (offset % 4);
593 594
        } else {
            megdnn_assert(use_dot);
M
Megvii Engine Team 已提交
595 596
            return group * FS_G + (offset / 4) * FS_OCB + (fh * FW + fw) * FS_SPATIAL +
                   (c / 4) * 16 + (offset % 4) * 4 + (c % 4);
597 598 599 600 601 602 603 604 605
        }
    };

    size_t filter_sizes = filter_meta.ocpg * filter_meta.icpg * FH * FW;
    for (size_t n = 0; n < N; ++n) {
        GroupCounter gc_out{filter_meta.ocpg};
        for (size_t oc = 0; oc < OC; ++oc, gc_out.next())
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow) {
M
Megvii Engine Team 已提交
606
                    comp_type dval = dptr[get_linear_addr(n, oc, oh, ow, dst.layout)];
607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623
                    Strategy::init_dval(dval);
                    ftype* fptr_cur = FilterVisitor::template get_current_ptr(
                            fptr, n, oc, oh, ow, filter_sizes);

                    for (size_t fh = 0; fh < FH; ++fh)
                        for (size_t fw = 0; fw < FW; ++fw) {
                            size_t ih = sh * oh + fh * dh + h_offset,
                                   iw = sw * ow + fw * dw + w_offset;
                            // here ih and iw are represented in unsigned int
                            // they will become very large if underflow occurs
                            if (ih < IH && iw < IW) {
                                size_t ic0 = gc_out.cur_grp * filter_meta.icpg,
                                       ic1 = ic0 + filter_meta.icpg;
                                for (size_t ic = ic0; ic < ic1; ++ic) {
                                    stype& sval = sptr[get_linear_addr(
                                            n, ic, ih, iw, src.layout)];
                                    ftype& fval = fptr_cur[get_filter_addr(
M
Megvii Engine Team 已提交
624 625 626 627 628
                                            gc_out.cur_grp, gc_out.cur_off, fh, fw,
                                            ic - ic0)];
                                    Strategy::on(
                                            sval, fval, dval, src.layout.dtype,
                                            filter_meta.dtype, dst.layout.dtype);
629 630 631 632
                                }
                            }
                        }
                    Strategy::write(
M
Megvii Engine Team 已提交
633
                            dval, dptr[get_linear_addr(n, oc, oh, ow, dst.layout)]);
634 635 636 637 638 639
                }
    }
}

//! forward with only filter ptr
template <typename stype, typename ftype, typename dtype, typename comp_type>
M
Megvii Engine Team 已提交
640 641 642
void forward(
        _megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
        const Convolution::CanonizedFilterMeta& filter_meta) {
643
    megdnn_assert(filter_meta.spatial_ndim == 2);
644 645 646 647 648 649 650 651 652 653
    megdnn_assert(
            filter_meta.format == param::Convolution::Format::NCHW ||
            filter_meta.format == param::Convolution::Format::NHWC ||
            filter_meta.format == param::Convolution::Format::NCHW88 ||
            filter_meta.format == param::Convolution::Format::NCHW44 ||
            filter_meta.format == param::Convolution::Format::NCHW44_DOT ||
            filter_meta.format == param::Convolution::Format::NCHW4 ||
            filter_meta.format == param::Convolution::Format::NCHW4_NCHW ||
            filter_meta.format == param::Convolution::Format::NCHW4_NCHW32 ||
            filter_meta.format == param::Convolution::Format::NCHW32_NCHW4);
654 655 656 657 658 659
    compute2d<stype, ftype, dtype, comp_type, StrategyFwd>(
            src, const_cast<ftype*>(fptr), dst, filter_meta);
}

//! forward with full filter (for API compatibility)
template <typename stype, typename ftype, typename dtype, typename comp_type>
M
Megvii Engine Team 已提交
660 661 662
void forward(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const Convolution::CanonizedFilterMeta& filter_meta) {
663 664 665 666 667 668 669 670 671
    if (filter_meta.format == param::Convolution::Format::NHWCD4) {
        return compute2d_hwcd4<stype, ftype, dtype, comp_type, StrategyFwd>(
                src, filter, dst, filter_meta);
    }
    return forward<stype, ftype, dtype, comp_type>(
            src, filter.compatible_ptr<ftype>(), dst, filter_meta);
}

template <typename ftype, typename dtype, typename gtype>
M
Megvii Engine Team 已提交
672 673 674
void backward_data(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        const Convolution::CanonizedFilterMeta& filter_meta) {
675 676 677 678 679 680 681 682 683 684 685
    memset(grad.raw_ptr, 0, grad.layout.span().dist_byte());
    megdnn_assert(filter_meta.spatial_ndim == 2);
    if (filter_meta.format == param::Convolution::Format::NHWCD4) {
        return compute2d_hwcd4<gtype, ftype, dtype, dtype, StrategyBwdData>(
                grad, filter, diff, filter_meta);
    }
    compute2d<gtype, ftype, dtype, dtype, StrategyBwdData>(
            grad, filter.compatible_ptr<ftype>(), diff, filter_meta);
}

template <typename stype, typename dtype, typename gtype>
M
Megvii Engine Team 已提交
686 687 688
void backward_filter(
        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        const Convolution::CanonizedFilterMeta& filter_meta) {
689 690 691 692 693 694
    memset(grad.raw_ptr, 0, grad.layout.span().dist_byte());
    megdnn_assert(filter_meta.spatial_ndim == 2);
    compute2d<stype, gtype, dtype, dtype, StrategyBwdFlt>(
            src, grad.compatible_ptr<gtype>(), diff, filter_meta);
}

M
Megvii Engine Team 已提交
695 696 697 698 699 700 701
template <
        typename stype, typename ftype, typename dtype, typename comp_type,
        typename FilterMeta, typename FilterVisitor = ConvFilterVisitor>
void forward_bias(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_out dst, dt_byte* /* workspace_ptr */,
        const FilterMeta& filter_meta) {
702 703 704 705
    megdnn_assert(filter_meta.spatial_ndim == 2);
    switch (filter_meta.format) {
        case param::Convolution::Format::NCHW:
        case param::Convolution::Format::NCHW88:
706
        case param::Convolution::Format::NCHW44:
707
        case param::Convolution::Format::NCHW44_DOT:
708 709
        case param::Convolution::Format::NHWC:
        case param::Convolution::Format::NCHW4:
710
        case param::Convolution::Format::NCHW4_NCHW:
711
        case param::Convolution::Format::NCHW4_NHWC:
712
        case param::Convolution::Format::NCHW4_NCHW32:
713 714
        case param::Convolution::Format::NCHW8:
        case param::Convolution::Format::NCHW32:
715
        case param::Convolution::Format::NCHW32_NCHW4:
716
        case param::Convolution::Format::CHWN4:
717
        case param::Convolution::Format::NCHW64:
M
Megvii Engine Team 已提交
718 719 720 721
            compute2d<
                    stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta,
                    FilterVisitor>(
                    src, filter.compatible_ptr<ftype>(), dst, filter_meta);
722 723
            break;
        case param::Convolution::Format::NHWCD4:
M
Megvii Engine Team 已提交
724 725 726
            compute2d_hwcd4<
                    stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta,
                    FilterVisitor>(src, filter, dst, filter_meta);
727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750
            break;
        default:
            megdnn_assert_internal(0);
    }

    //! we can not decide with bias.raw_ptr, as non bias the raw_ptr is not
    //! nullptr
    if (bias.layout.ndim != 0) {
        if (dst.layout.eq_shape(bias.layout) &&
            dst.layout.dtype.enumv() == bias.layout.dtype.enumv()) {
            dtype* dst_ptr = dst.compatible_ptr<dtype>();
            dtype* bias_ptr = bias.compatible_ptr<dtype>();
            for (size_t i = 0; i < dst.layout.span().dist_elem(); i++) {
                comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                static_cast<comp_type>(bias_ptr[0]);
                dst_ptr[0] = val;
                dst_ptr++;
                bias_ptr++;
            }
            return;
        }

        using Format = param::ConvBias::Format;
        switch (filter_meta.format) {
751 752
            case Format::NCHW:
            case Format::NCHW4_NCHW: {
753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769
                int dst_batch = dst.layout.shape[0];
                int dst_channel = dst.layout.shape[1];
                int chann_stride = dst.layout.shape[2] * dst.layout.shape[3];
                dtype* dst_ptr = dst.compatible_ptr<dtype>();

                for (int batch = 0; batch < dst_batch; ++batch) {
                    for (int chan = 0; chan < dst_channel; ++chan) {
                        dtype bias_val = bias.compatible_ptr<dtype>()[chan];
                        for (int i = 0; i < chann_stride; ++i, ++dst_ptr) {
                            comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                            static_cast<comp_type>(bias_val);
                            dst_ptr[0] = val;
                        }
                    }
                }
                break;
            };
M
Megvii Engine Team 已提交
770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788
#define BIAS_ADD_NCHWx(_pack_size)                                                    \
    do {                                                                              \
        megdnn_assert(dst.layout.is_contiguous());                                    \
        int dst_batch = dst.layout.shape[0];                                          \
        int dst_channel = dst.layout.shape[1] * (_pack_size);                         \
        int chann_stride = dst.layout.shape[2] * dst.layout.shape[3];                 \
        dtype* dst_ptr = dst.compatible_ptr<dtype>();                                 \
        for (int batch = 0; batch < dst_batch; ++batch) {                             \
            for (int chan = 0; chan < dst_channel; ++chan) {                          \
                dtype bias_val = bias.compatible_ptr<dtype>()[chan];                  \
                for (int i = 0; i < chann_stride; ++i) {                              \
                    int idx = batch * dst_channel * chann_stride +                    \
                              (chan / (_pack_size)) * (chann_stride * (_pack_size)) + \
                              i * (_pack_size) + chan % (_pack_size);                 \
                    dst_ptr[idx] = static_cast<comp_type>(dst_ptr[idx]) +             \
                                   static_cast<comp_type>(bias_val);                  \
                }                                                                     \
            }                                                                         \
        }                                                                             \
789
    } while (0)
790
            case Format::NCHW44:
791
            case Format::NCHW44_DOT:
792
            case Format::NCHW32_NCHW4:
793 794 795 796 797 798 799 800
            case Format::NCHW4: {
                BIAS_ADD_NCHWx(4);
                break;
            };
            case Format::NCHW8: {
                BIAS_ADD_NCHWx(8);
                break;
            };
M
Megvii Engine Team 已提交
801
            case Format::NCHW4_NCHW32:
802 803 804 805 806 807 808 809
            case Format::NCHW32: {
                BIAS_ADD_NCHWx(32);
                break;
            };
            case Format::NCHW88: {
                BIAS_ADD_NCHWx(8);
                break;
            };
810 811 812 813
            case Format::NCHW64: {
                BIAS_ADD_NCHWx(64);
                break;
            };
M
Megvii Engine Team 已提交
814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829
#define BIAS_ADD_CHWNx(_pack_size)                                                \
    do {                                                                          \
        megdnn_assert(dst.layout.is_contiguous());                                \
        int dst_batch = dst.layout.shape[3];                                      \
        int dst_channel = dst.layout.shape[0] * (_pack_size);                     \
        int chann_stride = dst.layout.shape[1] * dst.layout.shape[2] * dst_batch; \
        dtype* dst_ptr = dst.compatible_ptr<dtype>();                             \
        for (int chan = 0; chan < dst_channel; ++chan) {                          \
            dtype bias_val = bias.compatible_ptr<dtype>()[chan];                  \
            for (int i = 0; i < chann_stride; ++i) {                              \
                int idx = (chan / (_pack_size)) * chann_stride * (_pack_size) +   \
                          i * (_pack_size) + chan % (_pack_size);                 \
                dst_ptr[idx] = static_cast<comp_type>(dst_ptr[idx]) +             \
                               static_cast<comp_type>(bias_val);                  \
            }                                                                     \
        }                                                                         \
830 831 832 833 834
    } while (0)
            case Format::CHWN4: {
                BIAS_ADD_CHWNx(4);
                break;
            }
M
Megvii Engine Team 已提交
835
            case Format::NCHW4_NHWC:
836
            case Format::NHWC: {
M
Megvii Engine Team 已提交
837 838
                int dst_nhw =
                        dst.layout.shape[0] * dst.layout.shape[1] * dst.layout.shape[2];
839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859
                int dst_channel = dst.layout.shape[3];
                dtype* dst_ptr = dst.compatible_ptr<dtype>();

                for (int nhw = 0; nhw < dst_nhw; ++nhw) {
                    for (int chan = 0; chan < dst_channel; ++chan, ++dst_ptr) {
                        dtype bias_val = bias.compatible_ptr<dtype>()[chan];
                        comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                        static_cast<comp_type>(bias_val);
                        dst_ptr[0] = val;
                    }
                }
                break;
            };
            case Format::NHWCD4: {
                dtype* bias_ptr = bias.compatible_ptr<dtype>();
                dtype* dst_ptr = dst.compatible_ptr<dtype>();
                for (size_t n = 0; n < dst.layout[0]; n++) {
                    for (size_t h = 0; h < dst.layout[1]; h++) {
                        for (size_t cb = 0; cb < dst.layout[2]; cb++) {
                            for (size_t w = 0; w < dst.layout[3]; w++) {
                                for (size_t i = 0; i < 4; i++) {
M
Megvii Engine Team 已提交
860
                                    auto ptr = dst_ptr + n * dst.layout.stride[0] +
861 862 863 864
                                               h * dst.layout.stride[1] +
                                               cb * dst.layout.stride[2] +
                                               w * dst.layout.stride[3] +
                                               i * dst.layout.stride[4];
M
Megvii Engine Team 已提交
865 866 867
                                    comp_type val = static_cast<comp_type>(ptr[0]) +
                                                    static_cast<comp_type>(
                                                            bias_ptr[cb * 4 + i]);
868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886
                                    ptr[0] = val;
                                }
                            }
                        }
                    }
                }
                break;
            };
            default:
                megdnn_assert_internal(0);
        }
    }
}

}  // namespace convolution
}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen