#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"

#include "midout.h"

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
#include "src/arm_common/convolution/opr_impl.h"
#endif

#include <cstring>
#include <unordered_map>

MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;
using namespace fallback;

namespace {
//! Advance a typed pointer by a raw byte offset (not an element count).
//! Used to step through packed group/channel data whose stride is in bytes.
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace

//! Registry of every forward-convolution algorithm available in the fallback
//! backend: the fallback/naive kernels plus a wrapper around each packed
//! conv-bias algorithm.
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    //! owns the AlgoDefault wrappers created in the constructor
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->get_all_packed_algo();
        //! wrap every packed conv-bias algorithm so it can serve plain conv
        for (auto&& algorithm : conv_bias_algo) {
            refhold.emplace_back(new AlgoDefault(algorithm));
            m_all_algos.emplace_back(refhold.back().get());
        }

        m_all_algos.emplace_back(&algo_fallback);
        m_all_algos.emplace_back(&algo_naive);

        //! index by descriptor for get_algorithm_from_desc lookups
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

//! lazily-constructed singleton holding all fallback forward algorithms
const ConvolutionImpl::AlgoPack& ConvolutionImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

//! Return every registered fallback forward algorithm.
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}

//! Filter the packed algorithms down to those matching one data type and one
//! algorithm category.
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(
            nr_type_contain(target_type.data_type),
            "ConvBias algo selection only support one type");
    SmallVector<ConvolutionImpl::AlgoBase*> algos;
    for (auto&& algo : get_all_packed_algo()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}

//! An absent algorithm, or the "DEFAULT" algorithm, counts as naive and is
//! delegated to naive::ConvolutionForwardImpl by the callers.
bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    if (algo == nullptr) {
        return true;
    }
    return strcmp(algo->name(), "DEFAULT") == 0;
}
#define NCB_ALGO_FUNC(name, algo, param) static_cast<AlgoBase*>(algo)->name(param)
//! Run forward convolution: use the selected fallback algorithm when it is
//! non-naive and fits the provided workspace, else delegate to naive.
void ConvolutionImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter, workspace);
    auto&& algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(
                src, filter, dst, preprocessed_filter, workspace);
    }
}

void ConvolutionImpl::exec_preprocess(
        const TensorLayout& src_layout, _megdnn_tensor_in filter,
        const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
113 114 115
    //! exec_preprocess currently only support preprocess weights before exec,
    //! src/dst will be ignored, just set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
M
Megvii Engine Team 已提交
116
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter, workspace);
117 118

    //! should not pass workspace_size limit otherwise can not find match algo
119 120
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
M
Megvii Engine Team 已提交
121
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= workspace.size) {
122 123 124 125 126 127 128
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec_preprocess(
                src_layout, filter, dst_layout, preprocessed_filter, workspace);
    }
}

size_t ConvolutionImpl::get_workspace_in_bytes(
M
Megvii Engine Team 已提交
130
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
131
        const PreprocessedFilter* preprocessed_filter) {
132
    TensorLayoutArray layouts{src, filter, dst};
133
    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
M
Megvii Engine Team 已提交
134 135
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
136
    auto rst = AlgorithmCache::instance().get(key);
137 138 139 140
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

M
Megvii Engine Team 已提交
141
    auto fparam = make_ncb_kern_size_param(src, filter, dst, preprocessed_filter);
142
    auto&& algo = get_algorithm(fparam);
143 144
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
145
                src, filter, dst, preprocessed_filter);
146
    } else {
147
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
148 149 150 151
    }
}

size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
M
Megvii Engine Team 已提交
152
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
153
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
154
    auto&& algo = get_algorithm(fparam);
155 156 157 158
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_preprocess_workspace_in_bytes(
                src, filter, dst);
    } else {
159
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
160 161 162 163
    }
}

//! Layouts of the preprocessed filter tensors produced by exec_preprocess().
SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::deduce_preprocessed_filter_layout(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
M
Megvii Engine Team 已提交
176
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
177
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
178 179
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
M
Megvii Engine Team 已提交
180
        return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter, dst);
181 182 183 184
    }
    return ret;
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_safe(
M
Megvii Engine Team 已提交
186 187
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto ret_safe = ConvolutionImpl::get_all_algorithms(src, filter, dst);
188 189 190
    return ret_safe;
}

//! Heuristically pick an algorithm honoring workspace and attribute limits;
//! falls back to the naive implementation's heuristic when nothing matches.
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    return result;
}

//! Build the size-only NCB kernel parameter from the forward layouts.
ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    //! narrow size_t to uint32_t, asserting no overflow
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return v;
    };
    //! index of the first spatial axis (H) in the layout for this format
    auto format = param().format;
    size_t spatial_pos;
    if (format == Param::Format::NCHW88 || format == Param::Format::NCHW8 ||
        format == Param::Format::NCHW4 || format == Param::Format::NCHW44_DOT ||
        format == Param::Format::NCHW44 || format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else if (format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d", static_cast<int>(format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();

    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads,
            preprocessed_filter,
            handle()};
}

//! Build the full NCB kernel parameter (size info plus tensor/workspace
//! pointers) for execution.
ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.get_ref_ptr();
    ret.filter_ptr = filter.get_ref_ptr();
    ret.dst_ptr = dst.get_ref_ptr();
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}

void ConvolutionImpl::exec_preprocess_with_ncb_kern(
        const NCBKernParam& param, Algorithm* algo) {
265 266 267
    auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
268 269 270 271
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
272 273
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
274 275 276 277 278 279 280 281 282 283
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo) {
285 286 287
    auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
288 289 290 291
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
292 293
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
294
                "invalid conv format");
295
        auto run = [param, kernel](size_t index, size_t thread_id) {
296
            CpuNDRange ndrange_id(kernel.global_size, index);
297
            kernel.kern(param, {thread_id, ndrange_id});
298 299 300 301 302 303 304 305
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

//! Walk algorithm categories in suggested order and return the first
//! preferred (or otherwise first usable) algorithm fitting the workspace
//! limit and attribute constraints; nullptr when none qualifies.
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto algo_data_type = param.deduce_algo_data_type();
    auto suggest_category_order = suggest_algo_category_order(param);
    for (auto category : suggest_category_order) {
        auto&& origin_algos = select_algo_type({algo_data_type, category});
        ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
        for (auto i : origin_algos) {
            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
                    param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                    negative_attr);
            if (usable_attribute && static_cast<AlgoBase*>(i)->get_workspace(param) <=
                                            workspace_limit_in_bytes) {
                //! store the first usable algo if no prefer algo, choose it as
                //! the target algo
                if (!heuristic_algo) {
                    heuristic_algo = i;
                }
                //! choose the first prefer algo
                if (i->is_preferred(param)) {
                    return i;
                }
            }
        }
        if (heuristic_algo) {
            return heuristic_algo;
        }
    }
    return nullptr;
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
338 339
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
340
    for (auto&& i : get_all_packed_algo()) {
341 342
        if (i->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (i->is_preferred(param)) {
343 344 345 346 347 348 349 350 351 352
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}

//! Resolve an algorithm descriptor to the algorithm object; nullptr for an
//! invalid descriptor.
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
            case Handle::HandleType::NAIVE: {
                //! a naive descriptor can only be the single default algorithm
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_fwd_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}

//! Select the algorithm for \p param: an explicit execution policy wins;
//! otherwise the heuristic result is cached until the size parameter changes.
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    //! byte-wise comparison of the cached size param detects any change
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

//! Delegate category ordering to the conv-bias operator, which owns the
//! heuristic for the fallback backend.
SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    static CpuOprDelegationStorage<1> storage;
    auto conv_bias_opr = storage.get<ConvBias, 0>();
    auto conv_bias_param = ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
    return static_cast<ConvBiasImpl*>(conv_bias_opr)
            ->suggest_algo_category_order(conv_bias_param);
}

//! version tag for this backend's algorithm set — presumably consumed by the
//! persistent algorithm cache; bump when algorithms change incompatibly
const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}

//! Map the src/dst dtype combination onto the AlgoDataType enumeration used
//! by select_algo_type(); throws on unsupported combinations.
ConvolutionImpl::AlgoDataType ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type()
        const {
    if (src_type.enumv() == DTypeEnum::Float32) {
        return ConvolutionImpl::AlgoDataType::FLOAT32;
#if !MEGDNN_DISABLE_FLOAT16
    } else if (src_type.enumv() == DTypeEnum::Float16) {
        return ConvolutionImpl::AlgoDataType::FLOAT16;
#endif
    } else if (
            src_type.enumv() == DTypeEnum::Int8 ||
            src_type.enumv() == DTypeEnum::QuantizedS8) {
        //! int8 input may accumulate to int16 or int32 depending on dst
        if (dst_type.enumv() == DTypeEnum::Int16) {
            return ConvolutionImpl::AlgoDataType::INT8X8X16;
        } else {
            return ConvolutionImpl::AlgoDataType::QINT8X8X32;
        }
    } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
    } else if (
            src_type.enumv() == DTypeEnum::QuantizedS4 ||
            src_type.enumv() == DTypeEnum::Quantized4Asymm) {
        return ConvolutionImpl::AlgoDataType::QINT4x4x32;
    } else {
        megdnn_throw(ssprintf(
                "not support data type of %s * %s -> %s\n", src_type.name(),
                filter_type.name(), dst_type.name()));
    }
}

/* ===================== ConvolutionBackwardData ===================== */

//! Registry of backward-data algorithms; registration order below defines
//! the default preference order (NCHW44 matmul first, naive last).
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    AlgoDirect algo_direct;
    AlgoMatrixMul algo_matmul;
    AlgoMatrixMulNCHW44 algo_matmul_nchw44;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        m_all_algos.emplace_back(&algo_matmul_nchw44);
        m_all_algos.emplace_back(&algo_matmul);
        m_all_algos.emplace_back(&algo_direct);
        m_all_algos.emplace_back(&algo_naive);

        //! index by descriptor for get_algorithm_from_desc lookups
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }
    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
//! lazily-constructed singleton holding all backward-data algorithms
const ConvolutionBackwardDataImpl::AlgoPack& ConvolutionBackwardDataImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}
//! Return every registered backward-data algorithm.
SmallVector<ConvolutionBackwardDataImpl::AlgoBase*> ConvolutionBackwardDataImpl::
        get_all_packed_algo() {
    return algo_pack().all_algos();
}
//! Run backward-data convolution; formats/dtypes the fallback kernels do not
//! cover are delegated to the naive implementation.
void ConvolutionBackwardDataImpl::exec(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
         grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}

size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
485
    TensorLayoutArray layouts{filter, diff, grad};
486
    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
M
Megvii Engine Team 已提交
487 488
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
489
    auto rst = AlgorithmCache::instance().get(key);
490 491 492 493
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

494
    if (param().format == param::Convolution::Format::NHWCD4 ||
495
        param().format == param::Convolution::Format::NCHW4 ||
496 497
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
498
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
499 500 501 502 503 504 505
        return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_workspace_with_ncb(fparam);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
510
    if (param().format == param::Convolution::Format::NHWCD4 ||
511
        param().format == param::Convolution::Format::NCHW4 ||
512 513
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
514
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
515 516 517 518 519 520 521 522
        return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    auto ret = get_all_algorithms_with_ncb(fparam);
    return ret;
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms_safe(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
    auto ret_safe = ConvolutionBackwardDataImpl::get_all_algorithms(filter, diff, grad);
528 529 530 531
    megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm");
    return ret_safe;
}

//! Heuristically pick a backward-data algorithm; delegates to naive for
//! formats the fallback kernels do not cover.
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_heuristic(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
}

//! Build the size-only NCB kernel parameter for backward data.
ConvolutionBackwardDataImpl::NCBKernSizeParam ConvolutionBackwardDataImpl::
        make_ncb_kern_size_param(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
    //! narrow size_t to uint32_t, asserting no overflow
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return v;
    };
    //! index of the first spatial axis (H) in the layout for this format
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NCHW44) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
        spatial_pos = 1;
    }
    //! reuse the forward-direction layout check by viewing backward data as
    //! a forward conv with grad as input and diff as output (dtypes swapped)
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;

    std::swap(grad_fwd.dtype, diff_fwd.dtype);

    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}

//! Build the full NCB kernel parameter for backward data, asserting the
//! caller supplied a large enough workspace.
ConvolutionBackwardDataImpl::NCBKernParam ConvolutionBackwardDataImpl::
        make_ncb_kern_param(
                _megdnn_tensor_in filter, _megdnn_tensor_in diff,
                _megdnn_tensor_out grad, _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);

    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(
            workspace.size >= required_workspace_in_bytes,
            "required workspace: %zu; provided workspace: %zu",
            required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.get_ref_ptr();
    ret.diff_ptr = diff.get_ref_ptr();
    ret.grad_ptr = grad.get_ref_ptr();
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}

void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param) {
617 618 619
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
620
    auto&& algo = get_algorithm(p1g);
621
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
622
    if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) {
623 624 625
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
M
Megvii Engine Team 已提交
626 627
        megdnn_assert(
                p1g.filter_meta.format == Param::Format::NCHW ||
W
wangxiang 已提交
628 629
                        p1g.filter_meta.format == Param::Format::NHWC ||
                        p1g.filter_meta.format == Param::Format::NCHW44,
M
Megvii Engine Team 已提交
630
                "invalid conv format");
631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
W
wangxiang 已提交
647 648
            if (p1g.filter_meta.format == Param::Format::NCHW ||
                p1g.filter_meta.format == Param::Format::NCHW44) {
649 650 651 652 653 654 655 656 657
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
658 659 660
                p1g.diff_ptr += istrd;
                p1g.filter_ptr += fstrd;
                p1g.grad_ptr += ostrd;
661 662 663 664 665 666 667 668 669 670 671 672 673 674
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}

size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
675 676
        auto algo = get_algorithm(p1g);
        return ncb_1g_get_workspace(algo, p1g);
677
    }
678 679
    auto algo = get_algorithm(param);
    return ncb_1g_get_workspace(algo, param);
680 681
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
684 685 686 687 688 689 690 691
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

//! Heuristic selection on the NCB path, reduced to a single-group view for
//! multi-group convolutions.
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_heuristic_with_ncb(
                const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(
                p1g, workspace_limit_in_bytes, positive_attr, negative_attr);
    }
    return ncb_1g_get_algorithm_heuristic(
            param, workspace_limit_in_bytes, positive_attr, negative_attr);
}

//! Workspace for a single-group parameter; non-fallback algorithms report 0.
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    return 0;
}

//! Kernel entry point of \p algo for a single-group parameter; only
//! fallback-handled algorithms are dispatchable here.
ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::
        ncb_1g_dispatch_kern(Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);

    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }

    megdnn_throw("no suitable ConvolutionBackwardData algorithm");
}

//! Heuristic: prefer the matrix-mul algorithm when the per-group channel
//! product is large, or when the filter is 1x1 with unit stride and no
//! padding (which degenerates to a plain matrix multiply).
bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
        const NCBKernSizeParam& param) {
    const auto& meta = param.filter_meta;
    if (meta.ocpg * meta.icpg >= 32) {
        return true;
    }
    const bool is_1x1 = meta.spatial[0] == 1 && meta.spatial[1] == 1;
    const bool no_padding = meta.padding[0] == 0 && meta.padding[1] == 0;
    const bool unit_stride = meta.stride[0] == 1 && meta.stride[1] == 1;
    return is_1x1 && no_padding && unit_stride;
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        ncb_1g_get_all_algorithms(const NCBKernSizeParam& param) {
739
    std::vector<Algorithm*> ret;
740
    std::vector<Algorithm*> prefer_algos;
741
    for (auto&& i : get_all_packed_algo()) {
742 743 744
        if (i->usable(this, param)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
745
            } else {
746
                ret.push_back(i);
747 748 749
            }
        }
    }
750
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
751 752 753
    return ret;
}

//! Pick the first candidate that fits the workspace budget and satisfies
//! the attribute constraints; candidates arrive preferred-first from
//! ncb_1g_get_all_algorithms. Asserts if nothing qualifies.
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        ncb_1g_get_algorithm_heuristic(
                const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    for (auto candidate : ncb_1g_get_all_algorithms(param)) {
        // skip algorithms that would blow the workspace budget
        if (ncb_1g_get_workspace(candidate, param) > workspace_limit_in_bytes) {
            continue;
        }
        if (candidate->contain_attribute_all(positive_attr) &&
            !candidate->contain_attribute_any(negative_attr)) {
            return candidate;
        }
    }
    megdnn_assert(0, "no suitable algorithm found within given workspace limit");
}

//! Resolve an AlgorithmDesc back to a concrete algorithm object, routing
//! by the handle type the desc was created on. An invalid desc yields
//! nullptr (meaning: no explicit algorithm was requested).
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_from_desc(const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    }
    switch (desc.handle_type) {
        case Handle::HandleType::FALLBACK: {
            const auto& algo_map = algo_pack().all_algos_map();
            auto iter = algo_map.find(desc);
            megdnn_assert(iter != algo_map.end());
            return iter->second;
        }
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
        case Handle::HandleType::ARM_COMMON:
        case Handle::HandleType::AARCH64:
        case Handle::HandleType::ARMV7:
            // the arm backends keep their own algorithm registry
            return arm_common::ConvolutionBackwardDataImpl::get_algo_from_desc(desc);
#endif
        case Handle::HandleType::NAIVE: {
            // the naive handle exposes exactly one backward-data algorithm
            auto algo = static_cast<naive::HandleImpl*>(handle())
                                ->default_conv_bwd_data_algo();
            megdnn_assert(algo->info().desc == desc);
            return algo;
        }
        default:
            megdnn_throw("Unknown handle type");
            return nullptr;
    }
}

//! Algorithm selection entry point: honor an explicitly requested
//! algorithm from the execution policy if one is set; otherwise pick
//! heuristically and cache the choice keyed on the size param, so the
//! heuristic is not re-run for repeated identical shapes.
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::get_algorithm(
        const NCBKernSizeParam& param) {
    if (auto requested = get_algorithm_from_desc(execution_policy().algo)) {
        return requested;
    }
    // NOTE(review): cache validity relies on raw memcmp over the whole
    // NCBKernSizeParam, which assumes the struct compares bytewise
    // (no uninitialized padding) — confirm against its definition.
    const bool cache_hit =
            m_prev_selected_algo &&
            memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam)) ==
                    0;
    if (!cache_hit) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT,
                AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

//! Version tag identifying this backend's backward-data algorithm set;
//! the trailing number is bumped when the set of algorithms changes, so
//! previously recorded algorithm choices are invalidated.
const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}

// vim: syntax=cpp.doxygen