/**
 * \file dnn/test/common/opr_proxy.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include "src/common/opr_trait.h"

#include "test/common/deduce_layout_proxy.h"
#include "test/common/exec_proxy.h"
#include "test/common/fast_run_cache.h"
#include "test/common/inspect_type.h"
#include "test/common/opr_algo_proxy.h"
#include "test/common/timer.h"
#include "test/common/workspace_wrapper.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <unordered_map>

namespace megdnn {
namespace test {

template <Algorithm::OprType>
struct OprFromOprTypeTrait;

template <typename Opr>
struct OprTypeFromOprTrait;
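
//! OprFromOprTypeTrait and OprTypeFromOprTrait form a two-way compile-time
//! mapping between Algorithm::OprType enum values and the corresponding
//! megdnn operator classes; the specializations are generated by the cb()
//! macro below.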

#define cb(_opr_type, _opr)                                                           \
    template <>                                                                       \
    struct OprFromOprTypeTrait<Algorithm::OprType::_opr_type> {                       \
        using Opr = megdnn::_opr;                                                     \
    };                                                                                \
    template <>                                                                       \
    struct OprTypeFromOprTrait<megdnn::_opr> {                                        \
        constexpr static Algorithm::OprType opr_type = Algorithm::OprType::_opr_type; \
    }

cb(MATRIX_MUL_FORWARD, MatrixMulForward);
cb(BATCHED_MATRIX_MUL_FORWARD, BatchedMatrixMulForward);
cb(CONVOLUTION_FORWARD, ConvolutionForward);
cb(CONVOLUTION_BACKWARD_DATA, ConvolutionBackwardData);
cb(CONVOLUTION_BACKWARD_FILTER, ConvolutionBackwardFilter);
cb(CONVOLUTION3D_FORWARD, Convolution3DForward);
cb(CONVOLUTION3D_BACKWARD_DATA, Convolution3DBackwardData);
cb(CONVOLUTION3D_BACKWARD_FILTER, Convolution3DBackwardFilter);
cb(LOCAL_SHARE_FORWARD, LocalShareForward);
cb(LOCAL_SHARE_BACKWARD_DATA, LocalShareBackwardData);
cb(LOCAL_SHARE_BACKWARD_FILTER, LocalShareBackwardFilter);
cb(DEFORMABLE_CONV_FORWARD, DeformableConvForward);
cb(DEFORMABLE_CONV_BACKWARD_DATA, DeformableConvBackwardData);
cb(DEFORMABLE_CONV_BACKWARD_FILTER, DeformableConvBackwardFilter);
cb(BATCH_CONV_FORWARD, BatchConvBiasForward);
cb(CONVBIAS_FORWARD, ConvBiasForward);

#undef cb

// clang-format off
#define FOREACH_OPR_TYPE(cb) \
    cb(MATRIX_MUL_FORWARD) \
    cb(BATCHED_MATRIX_MUL_FORWARD) \
    cb(CONVOLUTION_FORWARD) \
    cb(CONVOLUTION_BACKWARD_DATA) \
    cb(CONVOLUTION_BACKWARD_FILTER) \
    cb(CONVOLUTION3D_FORWARD) \
    cb(CONVOLUTION3D_BACKWARD_DATA) \
    cb(CONVOLUTION3D_BACKWARD_FILTER) \
    cb(LOCAL_SHARE_FORWARD) \
    cb(LOCAL_SHARE_BACKWARD_DATA) \
    cb(LOCAL_SHARE_BACKWARD_FILTER) \
    cb(DEFORMABLE_CONV_FORWARD) \
    cb(DEFORMABLE_CONV_BACKWARD_DATA) \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER) \
    cb(BATCH_CONV_FORWARD) \
    cb(CONVBIAS_FORWARD)

#define FOREACH_OPR_TYPE_WITH_STMT(cb, stmt) \
    cb(MATRIX_MUL_FORWARD, stmt) \
    cb(BATCHED_MATRIX_MUL_FORWARD, stmt) \
    cb(CONVOLUTION_FORWARD, stmt) \
    cb(CONVOLUTION_BACKWARD_DATA, stmt) \
    cb(CONVOLUTION_BACKWARD_FILTER, stmt) \
    cb(CONVOLUTION3D_FORWARD, stmt) \
    cb(CONVOLUTION3D_BACKWARD_DATA, stmt) \
    cb(CONVOLUTION3D_BACKWARD_FILTER, stmt) \
    cb(LOCAL_SHARE_FORWARD, stmt) \
    cb(LOCAL_SHARE_BACKWARD_DATA, stmt) \
    cb(LOCAL_SHARE_BACKWARD_FILTER, stmt) \
    cb(DEFORMABLE_CONV_FORWARD, stmt) \
    cb(DEFORMABLE_CONV_BACKWARD_DATA, stmt) \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER, stmt) \
    cb(BATCH_CONV_FORWARD, stmt) \
    cb(CONVBIAS_FORWARD, stmt)

// clang-format on

#define _OPR_TYPE_CASE(_opr_type, _stmt)                                               \
    case Algorithm::OprType::_opr_type: {                                              \
        using _Opr = typename OprFromOprTypeTrait<Algorithm::OprType::_opr_type>::Opr; \
        _stmt;                                                                         \
        break;                                                                         \
    }

#define FOREACH_OPR_TYPE_DISPATCH(_search_items, _stmt)                         \
    for (size_t _item_idx = 0; _item_idx < _search_items.size(); _item_idx++) { \
        auto&& _item = _search_items[_item_idx];                                \
        switch (_item.opr_type) {                                               \
            FOREACH_OPR_TYPE_WITH_STMT(_OPR_TYPE_CASE, _stmt)                   \
            default:                                                            \
                megdnn_throw("unknown opr_type");                               \
        }                                                                       \
    }
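
//! Typical usage (see OprProxyProfilingBase::exec below): expand _stmt once
//! per search item, with _Opr bound to the concrete operator type of the
//! current _item, e.g.
//!
//!     FOREACH_OPR_TYPE_DISPATCH(search_items, {
//!         OprProxyProfilingBase<_Opr>::search(
//!                 _item.layouts, _item.param, W, opr->handle(), warmup_times,
//!                 exec_times, cache);
//!     });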

template <
        typename Opr, size_t arity = OprTrait<Opr>::arity,
        bool has_workspace = OprTrait<Opr>::has_workspace,
        bool can_deduce_layout = OprTrait<Opr>::can_deduce_layout>
struct OprProxyDefaultImpl : public DeduceLayoutProxy<Opr, arity, can_deduce_layout>,
                             public ExecProxy<Opr, arity, has_workspace> {};

template <typename Opr>
struct OprProxy : public OprProxyDefaultImpl<Opr> {};
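
//! A rough usage sketch of the default proxy; `opr`, `layouts` and `tensors`
//! are assumed to be prepared by the caller (e.g. the Checker test helper):
//!
//!     OprProxy<Opr> proxy;
//!     proxy.deduce_layout(opr, layouts);  // fill in the deduced output layout
//!     proxy.exec(opr, tensors);           // allocate workspace and execute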

template <typename Opr>
struct OprWeightPreprocessProxy : public OprProxyDefaultImpl<Opr> {};

template <typename Opr>
struct OprProxyVectorToSingle {};

template <>
struct OprProxy<ElemwiseForward> {
    static void deduce_layout(ElemwiseForward* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() >= 2);
        auto inp = layouts;
        inp.pop_back();
        opr->deduce_layout(inp, layouts.back());
    }

    static void exec(ElemwiseForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() >= 2);
        auto inp = tensors;
        inp.pop_back();
        opr->exec(inp, tensors.back());
    }
};

template <>
struct OprProxy<ElemwiseMultiType> {
    static void deduce_layout(ElemwiseMultiType* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() >= 2);
        auto inp = layouts;
        inp.pop_back();
        opr->deduce_layout(inp, layouts.back());
    }

    static void exec(ElemwiseMultiType* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() >= 2);
        auto inp = tensors;
        inp.pop_back();
        opr->exec(inp, tensors.back());
    }
};

template <>
struct OprProxy<ConcatForward> {
    WorkspaceWrapper W;
    static void deduce_layout(ConcatForward* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() >= 2);
        auto inp = layouts;
        inp.pop_back();
        opr->deduce_layout(inp, layouts.back());
    }

    void exec(ConcatForward* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        megdnn_assert(tensors.size() >= 2);
        auto inp = tensors;
        inp.pop_back();

        TensorLayoutArray layouts(tensors.size());
        std::transform(
                tensors.begin(), tensors.end(), layouts.begin(),
                [](const TensorND& tensor) { return tensor.layout; });
        auto inp_layouts = layouts;
        inp_layouts.pop_back();

        W.update(opr->get_workspace_in_bytes(inp_layouts, layouts.back()));
        auto inp_tensors = tensors;
        inp_tensors.pop_back();
        opr->exec(inp_tensors, tensors.back(), W.workspace());
    }
};

template <>
struct OprProxy<CheckNonFinite> {
    static void deduce_layout(CheckNonFinite* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() >= 2);
        auto inp = layouts;
        inp.pop_back();
        opr->deduce_layout(inp, layouts.back());
    }

    static void exec(CheckNonFinite* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() >= 2);
        auto inps = tensors;
        inps.pop_back();

        WorkspaceWrapper W(
                opr->handle(),
                opr->get_workspace_in_bytes(inps, tensors.back().layout));
        opr->exec(inps, tensors.back(), W.workspace());
    }
};

template <>
struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
    WorkspaceWrapper W;
    void exec(SplitForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() >= 2);
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        auto out = tensors;
        out.erase(out.begin());

        TensorLayoutArray layouts(tensors.size());
        std::transform(
                tensors.begin(), tensors.end(), layouts.begin(),
                [](const TensorND& tensor) { return tensor.layout; });
        auto out_layouts = layouts;
        out_layouts.erase(out_layouts.begin());

        W.update(opr->get_workspace_in_bytes(layouts.front(), out_layouts));

        auto out_tensors = tensors;
        out_tensors.erase(out_tensors.begin());
        opr->exec(tensors.front(), out_tensors, W.workspace());
    }
};

//! OprProxy impl for ternary oprs with profiling support
template <class Opr>
struct OprProxyProfilingBase
        : public DeduceLayoutProxy<
                  Opr, OprTrait<Opr>::arity, OprTrait<Opr>::can_deduce_layout> {
    static constexpr int arity = OprTrait<Opr>::arity;
    size_t warmup_times = 10, exec_times = 100;

    //! whether to enable profiling
    bool m_profiling;
    WorkspaceWrapper W;

    //! target algo set up by the profiler; it can also be specified directly
    //! by the caller
    ExecutionPolicy target_execution_policy;

    OprProxyProfilingBase(bool profile = false) { m_profiling = profile; }

    //! used to allocate tensors for weight preprocess
    static std::shared_ptr<TensorNDArray> alloc_tensors(
            Handle* handle, const TensorLayoutArray& layouts) {
        auto deleter = [handle](TensorNDArray* ptr) {
            for (auto&& i : *ptr) {
                auto pdata =
                        static_cast<dt_byte*>(i.raw_ptr()) + i.layout.span().low_byte;
                megdnn_free(handle, pdata);
            }
            delete ptr;
        };
        std::shared_ptr<TensorNDArray> ret{new TensorNDArray, deleter};
        for (size_t i = 0; i < layouts.size(); ++i) {
            auto span = layouts[i].span();
            ret->emplace_back(
                    static_cast<dt_byte*>(megdnn_malloc(handle, span.dist_byte())) -
                            span.low_byte,
                    layouts[i]);
        }
        return ret;
    }

    /**
     * \brief flatten the search space in postorder traversal
     *
     * The sub-opr search constructs a search tree:
     *
     *               A
     *            /  |  \
     *          B1   B2   C
     *        / | |  \
     *      D1 D2 D3  E
     *
     * We traverse the search tree in postorder:
     * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
     */
    static std::vector<Algorithm::SearchItem> flatten_search_space(
            const TensorLayoutArray& layouts, const std::string& param, Handle* handle) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);

        std::vector<Algorithm::SearchItem> ret;
        for (auto algo_info :
             AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
            Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
            std::vector<Algorithm::SearchItem>&& sub_items =
                    algo->get_subopr_list(layouts, opr.get());

            FOREACH_OPR_TYPE_DISPATCH(sub_items, {
                auto space = OprProxyProfilingBase<_Opr>::flatten_search_space(
                        _item.layouts, _item.param, handle);
                ret.insert(ret.end(), space.begin(), space.end());
            });
        }
        ret.push_back({OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
        return ret;
    }

    static void construct_execution_policy(
            const TensorLayoutArray& layouts, const std::string& param, Handle* handle,
            FastRunCache& cache, ExecutionPolicy& policy) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        if (!policy.algo.valid()) {
            policy.algo = cache.get(Algorithm::SearchItem{
                    OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
            megdnn_assert(
                    policy.algo.valid(),
                    "No cache found, maybe some error occurred in "
                    "flatten_search_space or get_subopr_list");
        }
        policy.sub_policy.clear();
        Algorithm* algo = opr->get_algorithm_from_desc(policy.algo);
        std::vector<Algorithm::SearchItem>&& sub_items =
                algo->get_subopr_list(layouts, opr.get());
        FOREACH_OPR_TYPE_DISPATCH(sub_items, {
            policy.sub_policy.push_back({});
            OprProxyProfilingBase<_Opr>::construct_execution_policy(
                    _item.layouts, _item.param, handle, cache,
                    policy.sub_policy.back());
        });
        return;
    }

    /**
     * \brief search and get the best execution_policy
     */
    static void search(
            const TensorLayoutArray& layouts, const std::string& param,
            WorkspaceWrapper& workspace_wrapper, Handle* handle, size_t warmup_times,
            size_t exec_times, FastRunCache& cache) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();

        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        SmallVector<size_t> sizes_in_bytes;
        for (const auto& layout : layouts) {
            sizes_in_bytes.push_back(layout.span().dist_byte());
        }

        float min_time = std::numeric_limits<float>::max();
        Algorithm::Info::Desc best_algo;

        std::string log_info = "Profiling start: ";
        for (auto&& layout : layouts) {
            log_info += layout.to_string() + " ";
        }
        megdnn_log("%s", log_info.c_str());
        best_algo = cache.get(Algorithm::SearchItem{
                OprTypeFromOprTrait<Opr>::opr_type, param, layouts});

        if (best_algo.valid()) {
            auto&& algo = opr->get_algorithm_from_desc(best_algo);
            MEGDNN_MARK_USED_VAR(algo);
            megdnn_log("Found best algo %s in cache", algo->name());
            return;
        }
        for (auto algo :
             AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
            //! construct execution_policy
            opr->execution_policy().algo = algo.desc;
            construct_execution_policy(
                    layouts, param, handle, cache, opr->execution_policy());

            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr.get(), layouts);
            sizes_in_bytes.push_back(workspace_size);

            WorkspaceBundle wb(nullptr, sizes_in_bytes);
            workspace_wrapper.update(wb.total_size_in_bytes());
            wb.set(workspace_wrapper.workspace().raw_ptr);
            TensorNDArray tensors;
            for (size_t i = 0; i < arity; i++) {
                tensors.push_back({wb.get(i), layouts[i]});
            }

            for (size_t times = 0; times < warmup_times; ++times) {
                AlgoProxy<Opr, arity>::exec(
                        opr.get(), tensors, wb.get_workspace(arity));
            }
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            Timer timer;
            timer.start();
            for (size_t times = 0; times < exec_times; ++times) {
                AlgoProxy<Opr, arity>::exec(
                        opr.get(), tensors, wb.get_workspace(arity));
            }
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            timer.stop();
            megdnn_log(
                    "%.3fms %s", timer.get_time_in_us() / 1e3, algo.desc.name.c_str());
            if (min_time > timer.get_time_in_us()) {
                min_time = timer.get_time_in_us();
                best_algo = algo.desc;
            }

            sizes_in_bytes.pop_back();
        }
        auto&& algo = opr->get_algorithm_from_desc(best_algo);
        MEGDNN_MARK_USED_VAR(algo);
        megdnn_log("Profiling end, got best algo: %s", algo->name());
        cache.put(
                Algorithm::SearchItem{
                        OprTypeFromOprTrait<Opr>::opr_type, param, layouts},
                best_algo);
    }

    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == arity);
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        if (m_profiling && !target_execution_policy.algo.valid()) {
            FastRunCache cache;
            std::string param_str;
            Algorithm::serialize_write_pod(opr->param(), param_str);
            auto&& search_items =
                    flatten_search_space(layouts, param_str, opr->handle());
            FOREACH_OPR_TYPE_DISPATCH(search_items, {
                OprProxyProfilingBase<_Opr>::search(
                        _item.layouts, _item.param, W, opr->handle(), warmup_times,
                        exec_times, cache);
            });

            construct_execution_policy(
                    layouts, param_str, opr->handle(), cache, opr->execution_policy());
            target_execution_policy = opr->execution_policy();
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
            W.update(workspace_size);
        }
        if (!target_execution_policy.algo.valid()) {
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
            W.update(workspace_size);
        }
        AlgoProxy<Opr, arity>::exec(opr, tensors, W.workspace());
    }
};

#define DEF_PROF(c)                                            \
    template <>                                                \
    struct OprProxy<c> : public OprProxyProfilingBase<c> {     \
        using OprProxyProfilingBase<c>::OprProxyProfilingBase; \
    }

DEF_PROF(MatrixMulForward);
DEF_PROF(ConvolutionForward);
DEF_PROF(ConvolutionBackwardData);
DEF_PROF(ConvolutionBackwardFilter);
DEF_PROF(LocalShareForward);
DEF_PROF(LocalShareBackwardData);
DEF_PROF(LocalShareBackwardFilter);

DEF_PROF(DeformableConvForward);
DEF_PROF(DeformableConvBackwardFilter);
DEF_PROF(BatchConvBiasForward);
DEF_PROF(ConvBiasForward);

DEF_PROF(DeformableConvBackwardData);
#undef DEF_PROF
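
//! A minimal profiling sketch (assuming `opr` and `tensors` are prepared by
//! the caller): constructing the proxy with `true` makes the first exec()
//! profile every available algorithm (including sub-oprs) and pick the best
//! execution policy before running.
//!
//!     OprProxy<ConvBiasForward> proxy{true};
//!     proxy.exec(opr, tensors);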

template <class Opr>
struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> {
    using Base = OprProxyProfilingBase<Opr>;
    static constexpr int arity = OprTrait<Opr>::arity;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == arity);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }

        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) {
                opr->execution_policy().algo = algo.desc;

                auto preprocess_tensors = weight_preprocess(opr, tensors, algo.desc);
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                typename Opr::PreprocessedFilter preprocessed_filter{
                        nullptr, *preprocess_tensors};

                auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                        opr, layouts, &preprocessed_filter);
                Base::W.update(workspace_size);

                for (size_t times = 0; times < Base::warmup_times; ++times) {
                    AlgoProxy<Opr, arity>::exec(
                            opr, tensors, &preprocessed_filter, Base::W.workspace());
                }
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    AlgoProxy<Opr, arity>::exec(
                            opr, tensors, &preprocessed_filter, Base::W.workspace());
                }
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo.desc.name.c_str());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_execution_policy.algo = algo.desc;
                }
            }
            opr->execution_policy() = Base::target_execution_policy;
            auto preprocess_tensors =
                    weight_preprocess(opr, tensors, Base::target_execution_policy.algo);
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            typename Opr::PreprocessedFilter preprocessed_filter{
                    nullptr, *preprocess_tensors};
            auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                    opr, layouts, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        auto preprocess_tensors =
                weight_preprocess(opr, tensors, Base::target_execution_policy.algo);
        megcoreSynchronize(opr->handle()->megcore_computing_handle());
        typename Opr::PreprocessedFilter preprocessed_filter{
                nullptr, *preprocess_tensors};
        if (!Base::target_execution_policy.algo.valid()) {
            auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                    opr, layouts, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        AlgoProxy<Opr, arity>::exec(
                opr, tensors, &preprocessed_filter, Base::W.workspace());
    }

    //! handle weight preprocess
    std::shared_ptr<TensorNDArray> weight_preprocess(
            Opr* opr, const TensorNDArray& tensors,
            const typename Opr::AlgorithmDesc&) {
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        auto weight_preprocess_layouts =
                AlgoProxy<Opr, arity>::deduce_preprocessed_filter_layout(opr, layouts);
        auto preprocessed_filter_tensors_ptr =
                Base::alloc_tensors(opr->handle(), weight_preprocess_layouts);
        typename Opr::PreprocessedFilter preprocessed_filter{
                nullptr, *preprocessed_filter_tensors_ptr};
        size_t preprocess_workspace_size =
                AlgoProxy<Opr, arity>::get_preprocess_workspace_in_bytes(opr, layouts);
        WorkspaceWrapper preprocess_workspace(opr->handle(), preprocess_workspace_size);
        AlgoProxy<Opr, arity>::exec_preprocess(
                opr, tensors, layouts, &preprocessed_filter,
                preprocess_workspace.workspace());
        return preprocessed_filter_tensors_ptr;
    }
};

#define DEF_PROF(c)                                                               \
    template <>                                                                   \
    struct OprWeightPreprocessProxy<c> : public OprWeightPreprocessProxyImpl<c> { \
        using OprWeightPreprocessProxyImpl<c>::OprWeightPreprocessProxyImpl;      \
    }

DEF_PROF(ConvolutionForward);
DEF_PROF(ConvBias);
#undef DEF_PROF

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen