profiler_impl.cpp 23.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/**
 * \file src/gopt/impl/profiler_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./opr_format_modifier.h"
#include "./utils.h"
#include "megbrain/gopt/framework.h"
16
#include "megbrain/gopt/profiler.h"
17 18 19 20
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/io.h"
21
#include "megbrain/opr/nn_int.h"
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
#include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h"

using namespace mgb;
using namespace cg;
using namespace opr;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;

namespace {
using OprFormat = Problem::OprFormat;

/*!
 * \brief map a TensorFormats value to the OprFormat of the same layout
 *
 * \throw MegBrainError when \p tensor_format has no OprFormat counterpart
 */
OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
    using TF = TensorFormats;
    switch (tensor_format) {
        case TF::NCHW:
            return OprFormat::NCHW;
        case TF::NCHWc4:
            return OprFormat::NCHW44;
        case TF::NCHWc8:
            return OprFormat::NCHW88;
        case TF::NCHWc32:
            return OprFormat::NCHW32;
        case TF::NCHWc64:
            return OprFormat::NCHW64;
        case TF::NHWC:
            return OprFormat::NHWC;
        case TF::CHWNc4:
            return OprFormat::CHWN4;
        default:
            break;
    }
    // every supported format returned above; anything else is an error
    mgb_throw(
            MegBrainError, "tensor format(%u) is not supported",
            static_cast<uint32_t>(tensor_format));
}

/*!
 * \brief graph plugin that records timer events around the kernels of the
 * operators accepted by a filter
 *
 * duration_in_usec() sums, over every recorded operator, the time elapsed
 * between its start and end events.
 */
class GraphPartitionProfiler final : public PluginBase {
    using CompNodeEventPtr = std::unique_ptr<CompNode::Event>;

public:
    //! predicate selecting the operators whose kernels are timed
    using OprFilter = thin_function<bool(OperatorNodeBase*)>;

    //! timer events recorded right before / right after an operator's kernel
    struct OprKernEvent {
        CompNodeEventPtr start, end;
    };

    GraphPartitionProfiler(ComputingGraph* graph, OprFilter opr_filter);
    ~GraphPartitionProfiler() noexcept;

    //! total elapsed time of all recorded operators, in microseconds
    float duration_in_usec() const;

private:
    //! create a timer-capable event on \p cn on first use, then record it
    void record_event(CompNodeEventPtr& event, CompNode cn) {
        if (!event) {
            event = cn.create_event(CompNode::Event::NEED_TIMER);
        }
        event->record();
    }

    ThinHashMap<OperatorNodeBase*, OprKernEvent> m_kern_event;
    OprFilter m_opr_filter;
};

M
Megvii Engine Team 已提交
78 79
GraphPartitionProfiler::GraphPartitionProfiler(
        ComputingGraph* graph, OprFilter opr_filter)
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
        : PluginBase(graph), m_opr_filter(opr_filter) {
    using namespace event;
    auto on_before_kern = [this](BeforeKernel const& event) {
        if (!m_opr_filter(event.opr))
            return;
        auto evptr = &m_kern_event[event.opr].start;
        record_event(*evptr, event.comp_node);
    };
    auto on_after_kern = [this](AfterKernel const& event) {
        if (!m_opr_filter(event.opr))
            return;
        auto evptr = &m_kern_event[event.opr].end;
        record_event(*evptr, event.comp_node);
    };
    auto&& ev = graph->event();
    add_event_handler(ev.register_receiver<BeforeKernel>(on_before_kern));
    add_event_handler(ev.register_receiver<AfterKernel>(on_after_kern));
}

GraphPartitionProfiler::~GraphPartitionProfiler() noexcept {
    // host-wait on every event that was created before the events are released
    for (auto&& kv : m_kern_event) {
        const OprKernEvent& ev = kv.second;
        if (ev.start)
            ev.start->host_wait();
        if (ev.end)
            ev.end->host_wait();
    }
}

float GraphPartitionProfiler::duration_in_usec() const {
    float device_duration = 0.f;
    for (auto&& kern_ev : m_kern_event) {
        auto&& event = kern_ev.second;
        event.end->host_wait();
        device_duration += 1e6 * event.start->elapsed_time_until(*event.end);
    }
    return device_duration;
}

/*!
 * \brief An operator that indicates its input var node is contiguous
 *
 * It performs no computation (scn_do_execute is empty). It only places a
 * contiguous-layout constraint on its single input, and its output shape is
 * inferred as identical to the input shape. The profiling routines compile
 * graphs whose only requested output is this marker placed on the profiled
 * var, so that var is required to be contiguous.
 */
// clang-format off
MGB_DEFINE_OPR_CLASS(MarkInputContiguous, SingleCNOperatorNodeBase) //{
    void scn_do_execute() override {}
    void init_output_static_infer_desc() override;
    void add_input_layout_constraint() override {
        input(0)->add_layout_constraint_contiguous();
    }
public:
    MarkInputContiguous(VarNode* input, const OperatorNodeConfig& config);
    static SymbolVar make(SymbolVar input, const OperatorNodeConfig& config = {});
};
// clang-format on

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkInputContiguous);

M
Megvii Engine Team 已提交
138 139
MarkInputContiguous::MarkInputContiguous(
        VarNode* input, const OperatorNodeConfig& config)
140 141 142 143 144
        : Super(input->owner_graph(), config, "mark_contiguous", {input}) {
    add_input({input});
    add_output(None);
}

M
Megvii Engine Team 已提交
145 146
SymbolVar MarkInputContiguous::make(SymbolVar input, const OperatorNodeConfig& config) {
    return input.insert_single_output_opr<MarkInputContiguous>(input.node(), config);
147 148 149 150 151
}

void MarkInputContiguous::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
M
Megvii Engine Team 已提交
152
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
153 154 155 156
}
}  // namespace

/* ================== ProfilerImpl =================*/
157
ProfilerImpl::ProfilerImpl(int runs, float opr_threshold, float var_node_threshold)
158 159 160
        : m_opr_threshold{opr_threshold},
          m_var_node_threshold{var_node_threshold},
          m_runs{runs} {
161
    m_opr_filter = [this](const OperatorNodeBase* opr, OperatorNodeBase* new_opr) {
162 163 164 165 166
        /// \note: for the considerations of performance, we skip nchw(naive)
        /// kernels for conv bias on CUDA platform. to remove this later
        if (auto conv = try_cast_as_op<opr::ConvBiasForward>(new_opr)) {
            if (conv->output(0)->comp_node().device_type() ==
                        CompNode::DeviceType::CUDA &&
167
                conv->input(0)->dtype().category() == DTypeCategory::QUANTIZED &&
168 169 170 171
                conv->param().format == OprFormat::NCHW) {
                return false;
            }
        }
172 173
        float comp1 =
                m_opr_footprint.get_computation(const_cast<OperatorNodeBase*>(opr));
174 175 176 177 178
        float comp2 = m_opr_footprint.get_computation(new_opr);
        if (comp2 > m_opr_threshold * comp1)
            return false;
        return true;
    };
179 180
    m_var_node_filter = [this](const VarNode* var, TensorShape from, TensorShape to,
                               ReformatKey key) {
181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
        /// \note: due to the alignment requirement of low-bit tensor, we skip
        /// some layout transform for low-bit tensors. The skipped layout
        /// transforms do not have corresponding dnn kernel and cannot be
        /// implemented by tensor manip operators (like reshape, dimshuffle,
        /// subtensor, etc.).
        if (var->dtype().enumv() == DTypeEnum::QuantizedS4 ||
            var->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
            if (key.input_format == TensorFormats::NCHW &&
                key.output_format != TensorFormats::NHWC &&
                key.output_format != TensorFormats::NCHWc64) {
                return false;
            }
            if (key.output_format == TensorFormats::NCHW &&
                key.input_format != TensorFormats::NHWC &&
                key.input_format != TensorFormats::NCHWc64) {
                return false;
            }
        }
        TensorLayout orig_ly = {var->shape(), var->dtype()},
                     from_ly = {from, var->dtype()}, to_ly = {to, var->dtype()};
        float orig_memory = orig_ly.span().dist_byte() * 2.f;
202
        float reformat_memory = from_ly.span().dist_byte() + to_ly.span().dist_byte();
203 204 205 206 207
        if (reformat_memory > orig_memory * m_var_node_threshold)
            return false;
        return true;
    };
}
208 209 210

ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
        const OperatorNodeBase* opr, TensorFormats base_format,
211 212
        const SmallVector<TensorFormats>& available_tensor_formats,
        ReformatAttribute extra_attribute) const {
213 214 215 216 217
    OperatorNodeRecord record;
    record.opr = opr;
    auto& costs = record.costs;
    for (auto&& f : available_tensor_formats) {
        auto opr_format = tensor_formats_to_opr_format(f);
M
Megvii Engine Team 已提交
218
        costs[opr_format] = profile_operator(opr, base_format, f, extra_attribute);
219 220 221 222
    }
    return record;
}

M
Megvii Engine Team 已提交
223 224 225
float ProfilerImpl::profile_operator(
        const OperatorNodeBase* opr, TensorFormats base_format,
        TensorFormats tensor_format, ReformatAttribute extra_attribute) const {
226 227 228 229 230 231 232 233 234
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    graph->options().var_sanity_check_first_run = false;
    VarNodeArray new_inps(opr->input().size());
    for (size_t i = 0; i < opr->input().size(); ++i) {
        auto&& var = opr->input(i);
        auto&& cn = var->comp_node();
        auto&& dtype = var->dtype();
        auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
235 236
        auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
                var, base_format, tensor_format, extra_attribute);
237 238 239 240 241 242
        dval->resize(aligned_tensor_shape);
        auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
        new_inps[i] = aligned_var.node();
    }
    auto new_opr = serialization::copy_opr_shallow(
            *opr, new_inps, opr->config(), {graph.get()});
243 244
    if (!m_opr_filter(opr, new_opr))
        return PROFILE_TIME_OUT;
245 246 247 248
    auto y = new_opr->output(0);
    auto mark = MarkInputContiguous::make(SymbolVar(y));
    auto func = graph->compile({{mark, {}}});
    auto filter = [new_opr](OperatorNodeBase* opr) { return opr == new_opr; };
M
Megvii Engine Team 已提交
249 250
    auto profiler =
            std::make_unique<GraphPartitionProfiler>(graph.get(), std::move(filter));
251 252 253 254 255 256
    for (int i = 0; i < m_runs; ++i)
        func->execute();
    return profiler->duration_in_usec();
}

/*!
 * \brief profile \p opr under each configuration in \p available_configs
 *
 * The returned record maps each configuration's opr_format to its cost.
 */
ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
        const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config,
        const SmallVector<OprTensorFormatsConfiguration>& available_configs,
        ReformatAttribute extra_attribute) const {
    OperatorNodeRecord result;
    result.opr = opr;
    for (const auto& cfg : available_configs) {
        const float cost = profile_operator(opr, base_config, cfg, extra_attribute);
        result.costs[cfg.opr_format] = cost;
    }
    return result;
}

float ProfilerImpl::profile_operator(
M
Megvii Engine Team 已提交
270
        const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config,
271 272
        const OprTensorFormatsConfiguration& config,
        ReformatAttribute extra_attribute) const {
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    graph->options().var_sanity_check_first_run = false;
    VarNodeArray new_inps(opr->input().size());
    size_t i = 0;
    size_t nr_input_tensor =
            std::min(config.input_tensor_formats.size(), opr->input().size());
    for (; i < nr_input_tensor; ++i) {
        auto&& var = opr->input(i);
        auto&& cn = var->comp_node();
        auto&& dtype = var->dtype();
        auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
        TensorShape aligned_shape;
        if (config.input_tensor_types[i] == TensorType::WEIGHT) {
            mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT);
288
            aligned_shape = ReformatManager::make_aligned_weight_shape(
289
                    var, base_config.input_tensor_formats[i],
M
Megvii Engine Team 已提交
290 291
                    config.input_tensor_formats[i], config.output_tensor_formats[0],
                    extra_attribute);
292
        } else {
M
Megvii Engine Team 已提交
293 294 295
            mgb_assert(
                    base_config.input_tensor_types[i] == config.input_tensor_types[i]);
            mgb_assert(base_config.input_tensor_types[i] == TensorType::FEATURE);
296
            aligned_shape = ReformatManager::make_aligned_tensor_shape(
297
                    var, base_config.input_tensor_formats[i],
298
                    config.input_tensor_formats[i], extra_attribute);
299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
        }
        dval->resize(aligned_shape);
        auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
        new_inps[i] = aligned_var.node();
    }
    for (; i < opr->input().size(); ++i) {
        auto&& var = opr->input(i);
        auto&& cn = var->comp_node();
        auto&& dtype = var->dtype();
        auto hval = std::make_shared<HostTensorND>(cn, dtype);
        hval->resize(var->shape());
        auto cb = [&](DeviceTensorND& d) { hval->copy_from(d).sync(); };
        {
            auto cg = var->owner_graph();
            cg->compile({{var, cb}})->execute();
        }
        auto imm = opr::ImmutableTensor::make(*graph, *hval);
        new_inps[i] = imm.node();
    }
M
Megvii Engine Team 已提交
318
    VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr);
319 320 321 322 323 324 325 326 327 328 329
#if 0
    static const ThinHashSet<Typeinfo*> multi_algo_oprs = {
            opr::Convolution::typeinfo(),
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionBackwardData::typeinfo(),
            opr::PoolingForward::typeinfo(),
    };
    if (multi_algo_oprs.count(opr->dyn_typeinfo()) &&
        !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr()))
        return PROFILE_TIME_OUT;
#endif
330 331
    if (!m_opr_filter(opr, y->owner_opr()))
        return PROFILE_TIME_OUT;
332 333 334 335
    auto mark = MarkInputContiguous::make(SymbolVar(y));
    auto func = graph->compile({{mark, {}}});
    auto new_opr = y->owner_opr();
    auto filter = [&new_opr](OperatorNodeBase* opr) { return opr == new_opr; };
M
Megvii Engine Team 已提交
336 337
    auto profiler =
            std::make_unique<GraphPartitionProfiler>(graph.get(), std::move(filter));
338 339 340 341 342 343 344 345
    for (int i = 0; i < m_runs; ++i)
        func->execute();
    return profiler->duration_in_usec();
}

/*!
 * \brief profile the layout transform of \p var between every ordered pair of
 * distinct formats in \p available_tensor_formats
 */
ProfilerImpl::VarNodeRecord ProfilerImpl::profile_var_node(
        const VarNode* var, TensorFormats base_format,
        const SmallVector<TensorFormats>& available_tensor_formats,
        ReformatAttribute attribute) const {
    VarNodeRecord result;
    result.var = var;
    const auto dtype_enum = var->dtype().enumv();
    for (const auto& src : available_tensor_formats) {
        for (const auto& dst : available_tensor_formats) {
            if (src != dst) {
                ReformatKey key{src, dst, attribute, dtype_enum, dtype_enum};
                result.costs[{src, dst}] = profile_var_node(var, base_format, key);
            }
        }
    }
    return result;
}

M
Megvii Engine Team 已提交
362 363
float ProfilerImpl::profile_var_node(
        const VarNode* var, TensorFormats base_format, const ReformatKey& key) const {
364 365 366
    auto&& cn = var->comp_node();
    auto&& dtype = var->dtype();
    auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
367 368
    auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
            var, base_format, key.input_format, key.attribute);
369 370 371 372 373 374 375 376
    dval->resize(aligned_tensor_shape);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    graph->options().var_sanity_check_first_run = false;
    auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
    auto builder = ReformatManager::instance().auto_aligned_reformat_featrue(
            var, base_format, key);
    auto y = builder({aligned_var.node()});
377 378

    if (!m_var_node_filter(var, aligned_tensor_shape, y->shape(), key))
379
        return PROFILE_TIME_OUT;
380 381 382 383 384 385 386
    ThinHashSet<OperatorNodeBase*> set;
    DepOprIter iter([&set](OperatorNodeBase* opr) { set.insert(opr); });
    iter.add(y->owner_opr());
    iter.set_visited(aligned_var.node()->owner_opr());
    auto mark = MarkInputContiguous::make(SymbolVar(y));
    auto func = graph->compile({{mark, {}}});
    auto filter = [&set](OperatorNodeBase* opr) { return set.count(opr) > 0; };
M
Megvii Engine Team 已提交
387 388
    auto profiler =
            std::make_unique<GraphPartitionProfiler>(graph.get(), std::move(filter));
389 390 391 392 393
    for (int i = 0; i < m_runs; ++i)
        func->execute();
    return profiler->duration_in_usec();
}

M
Megvii Engine Team 已提交
394
ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) const {
395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413
    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
    {
        auto cb = [&cvprop](OperatorNodeBase* opr) { cvprop.add_opr(opr); };
        DepOprIter iter{cb};
        for (auto&& o : problem.graph_partition().output()) {
            iter.add(o->owner_opr());
        }
    }

    static const ThinHashMap<Typeinfo*, size_t> format_aware_input_tensors = {
#define cb(_Opr, _arity) {_Opr::typeinfo(), _arity}
            cb(Convolution, 2),
            cb(ConvBiasForward, 4),
            cb(ConvolutionBackwardData, 2),
            cb(PoolingForward, 1),
            cb(WarpPerspective, 1),
            cb(Resize, 1),
#undef cb
    };
414
    static const ThinHashSet<Typeinfo*> skip_opr_types = {
M
Megvii Engine Team 已提交
415
            TypeCvt::typeinfo(), Elemwise::typeinfo(), ElemwiseMultiType::typeinfo()};
416 417
    ThinHashSet<VarNode*> vars;
    ThinHashSet<OperatorNodeBase*> oprs;
418 419 420 421 422 423 424 425 426
    ThinHashSet<OperatorNodeBase*> skip_oprs;
    for (auto&& opr : problem.graph_partition().all_oprs()) {
        if (cvprop.is_const(opr))
            continue;
        bool skip = true;
        for (auto&& i : opr->input()) {
            skip &= problem.graph_partition().input().count(i) > 0 ||
                    skip_oprs.count(i->owner_opr()) > 0;
        }
427 428
        auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
        skip &= find == format_aware_input_tensors.end();
429 430 431 432 433 434 435
        if (skip)
            skip_oprs.insert(opr);
        oprs.insert(opr);
        if (find == format_aware_input_tensors.end()) {
            for (auto&& i : opr->input()) {
                if (!cvprop.is_const(i)) {
                    vars.insert(i);
436
                }
437 438
            }
        } else {
M
Megvii Engine Team 已提交
439
            size_t nr_input_tensor = std::min(find->second, opr->input().size());
440 441 442
            for (size_t i = 0; i < nr_input_tensor; ++i) {
                if (!cvprop.is_const(opr->input(i))) {
                    vars.insert(opr->input(i));
443 444 445
                }
            }
        }
446 447
        for (auto&& ov : opr->usable_output()) {
            vars.insert(ov);
448 449 450 451 452
        }
    }

    auto base_format = problem.base_format();
    auto&& available_tensor_formats = problem.available_tensor_formats();
453
    auto&& reformat_attribute = problem.attribute().reformat_attribute;
454 455 456 457 458

    ProfilingResult profiling_result;
    auto& opr_record = profiling_result.opr_record;
    auto& var_record = profiling_result.var_record;
    for (auto&& var : vars) {
459 460
        var_record[var] = profile_var_node(
                var, base_format, available_tensor_formats, reformat_attribute);
461 462 463 464 465
    }
    for (auto&& opr : oprs) {
        auto&& opr_configs = problem.opr_configs();
        auto find = opr_configs.find(opr->dyn_typeinfo());
        if (find == opr_configs.end()) {
466 467
            if (skip_oprs.count(opr) > 0) {
                SmallVector<TensorFormats> tensor_formats = {base_format};
468 469
                opr_record[opr] = profile_operator(
                        opr, base_format, tensor_formats, reformat_attribute);
470
            } else {
M
Megvii Engine Team 已提交
471 472
                opr_record[opr] = profile_operator(
                        opr, base_format, available_tensor_formats, reformat_attribute);
473
            }
474 475 476 477 478 479 480 481 482 483
        } else {
            auto&& dispatchers = find->second;
            SmallVector<OprTensorFormatsConfiguration> configs;
            for (const auto& item : dispatchers) {
                auto config = (*item.second)(opr);
                if (config.valid()) {
                    configs.emplace_back(config.val());
                }
            }
            auto base_config = problem.base_config(opr);
M
Megvii Engine Team 已提交
484 485
            opr_record[opr] =
                    profile_operator(opr, base_config, configs, reformat_attribute);
486 487 488 489 490 491 492 493 494 495 496 497 498
        }
    }
    for (auto&& rpair : opr_record) {
        mgb_log_debug("%s", rpair.second.to_string().c_str());
    }
    for (auto&& rpair : var_record) {
        mgb_log_debug("%s", rpair.second.to_string().c_str());
    }
    return profiling_result;
}

/* ================== ProfilerBase =================*/
std::string ProfilerBase::OperatorNodeRecord::to_string() const {
M
Megvii Engine Team 已提交
499 500 501
    auto str = ssprintf(
            "\nopr type: %s\nopr name: %s\ninputs:\n", opr->dyn_typeinfo()->name,
            opr->cname());
502
    for (auto&& i : opr->input()) {
M
Megvii Engine Team 已提交
503 504
        str += ssprintf(
                "\tvar: %s\n\tshape: %s\n", i->cname(), i->shape().to_string().c_str());
505
    }
M
Megvii Engine Team 已提交
506 507 508
    str += ssprintf(
            "outputs:\n\tvar: %s\n\tshape: %s\ncosts:\n", opr->output(0)->cname(),
            opr->output(0)->shape().to_string().c_str());
509
    for (auto&& cpair : costs) {
M
Megvii Engine Team 已提交
510 511 512
        str += ssprintf(
                "\tformat: %s; cost:%f", opr_format_to_string(cpair.first),
                cpair.second);
513 514 515 516 517 518 519 520
    }
    return str;
}

std::string ProfilerBase::VarNodeRecord::to_string() const {
    auto str = ssprintf("\nvar: %s\ncosts:", var->cname());
    for (auto&& cpair : costs) {
        auto&& formats = cpair.first;
M
Megvii Engine Team 已提交
521 522 523 524 525 526 527
        str += ssprintf(
                "\n\tformat: (i:%s;o:%s); cost:%f",
                tensor_formats_to_named_tensor_shape(formats.first).to_string().c_str(),
                tensor_formats_to_named_tensor_shape(formats.second)
                        .to_string()
                        .c_str(),
                cpair.second);
528 529 530 531 532 533 534 535
    }
    return str;
}

std::unique_ptr<ProfilerBase> ProfilerBase::make_profiler() {
    return std::make_unique<ProfilerImpl>();
}

536
std::unique_ptr<ProfilerBase> ProfilerBase::make_cached_profiler(const char* path) {
537 538 539 540
    return std::make_unique<CachedProfiler>(path);
}

/* ================== CachedProfiler =================*/
541 542
CachedProfiler::CachedProfiler(
        const char* path, int runs, float opr_threshold, float var_node_threshold)
543 544
        : ProfilerImpl(runs, opr_threshold, var_node_threshold), m_path{path} {
    if (m_path != nullptr) {  // file cache
545
        ProfilerCache::inst().set_impl(std::make_unique<InFilePersistentCache>(m_path));
546 547 548
    }
}

549
CachedProfiler::ProfilingResult CachedProfiler::profile(const Problem& problem) const {
550 551 552 553 554 555 556 557 558
    auto ret = ProfilerImpl::profile(problem);
    if (m_path != nullptr)
        ProfilerCache::inst().dump_cache(m_path);
    return ret;
}

float CachedProfiler::profile_operator(
        const OperatorNodeBase* opr, TensorFormats base_format,
        TensorFormats tensor_format, ReformatAttribute extra_attribute) const {
559 560
    ProfilerCache::Key key{
            opr, tensor_formats_to_opr_format(tensor_format), extra_attribute};
561 562 563
    auto ret = ProfilerCache::inst().get(key);
    if (ret.valid())
        return ret.val();
564 565
    auto rst = ProfilerImpl::profile_operator(
            opr, base_format, tensor_format, extra_attribute);
566 567 568 569 570
    ProfilerCache::inst().put(key, rst);
    return rst;
}

float CachedProfiler::profile_operator(
571
        const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config,
572 573 574 575 576 577
        const OprTensorFormatsConfiguration& config,
        ReformatAttribute extra_attribute) const {
    ProfilerCache::Key key{opr, config.opr_format, extra_attribute};
    auto ret = ProfilerCache::inst().get(key);
    if (ret.valid())
        return ret.val();
578 579
    auto rst =
            ProfilerImpl::profile_operator(opr, base_config, config, extra_attribute);
580 581 582 583
    ProfilerCache::inst().put(key, rst);
    return rst;
}

584 585
float CachedProfiler::profile_var_node(
        const VarNode* var, TensorFormats base_format, const ReformatKey& key) const {
586 587 588 589 590 591 592 593 594
    ProfilerCache::Key pf_key{var, key};
    auto ret = ProfilerCache::inst().get(pf_key);
    if (ret.valid())
        return ret.val();
    auto rst = ProfilerImpl::profile_var_node(var, base_format, key);
    ProfilerCache::inst().put(pf_key, rst);
    return rst;
}

595
// vim: syntax=cpp.doxygen