/**
 * \file src/gopt/impl/profiler_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./opr_format_modifier.h"
#include "./utils.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h"

using namespace mgb;
using namespace cg;
using namespace opr;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;

namespace {
using OprFormat = Problem::OprFormat;
OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
    switch (tensor_format) {
        case TensorFormats::NCHW:
            return OprFormat::NCHW;
        case TensorFormats::NCHWc4:
            return OprFormat::NCHW4;
        case TensorFormats::NCHWc8:
            return OprFormat::NCHW8;
        case TensorFormats::NCHWc32:
            return OprFormat::NCHW32;
        case TensorFormats::NCHWc64:
            return OprFormat::NCHW64;
        case TensorFormats::NHWC:
            return OprFormat::NHWC;
        case TensorFormats::CHWNc4:
            return OprFormat::CHWN4;
        default:
            mgb_throw(
                    MegBrainError, "tensor format(%u) is not supported",
                    static_cast<uint32_t>(tensor_format));
    }
}

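/*!
 * \brief a plugin that records start/end device events around the kernel
 * execution of the operators selected by the given filter, and reports the
 * accumulated device time
 *
 * A minimal usage sketch (mirroring the profiling routines below; \c graph,
 * \c func and \c filter are assumed to be set up by the caller):
 * \code
 * auto profiler = std::make_unique<GraphPartitionProfiler>(
 *         graph.get(), std::move(filter));
 * for (int i = 0; i < runs; ++i)
 *     func->execute();
 * float usec = profiler->duration_in_usec();
 * \endcode
 */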
class GraphPartitionProfiler final : public PluginBase {
    using CompNodeEventPtr = std::unique_ptr<CompNode::Event>;

public:
    using OprFilter = thin_function<bool(OperatorNodeBase*)>;
    struct OprKernEvent {
        CompNodeEventPtr start, end;
    };
    GraphPartitionProfiler(ComputingGraph* graph, OprFilter opr_filter);
    ~GraphPartitionProfiler() noexcept;
    float duration_in_usec() const;

private:
    void record_event(CompNodeEventPtr& dest, CompNode cn) {
        if (dest == nullptr)
            dest = cn.create_event(CompNode::Event::NEED_TIMER);
        dest->record();
    }
    ThinHashMap<OperatorNodeBase*, OprKernEvent> m_kern_event;
    OprFilter m_opr_filter;
};

GraphPartitionProfiler::GraphPartitionProfiler(
        ComputingGraph* graph, OprFilter opr_filter)
        : PluginBase(graph), m_opr_filter(opr_filter) {
    using namespace event;
    auto on_before_kern = [this](BeforeKernel const& event) {
        if (!m_opr_filter(event.opr))
            return;
        auto evptr = &m_kern_event[event.opr].start;
        record_event(*evptr, event.comp_node);
    };
    auto on_after_kern = [this](AfterKernel const& event) {
        if (!m_opr_filter(event.opr))
            return;
        auto evptr = &m_kern_event[event.opr].end;
        record_event(*evptr, event.comp_node);
    };
    auto&& ev = graph->event();
    add_event_handler(ev.register_receiver<BeforeKernel>(on_before_kern));
    add_event_handler(ev.register_receiver<AfterKernel>(on_after_kern));
}

GraphPartitionProfiler::~GraphPartitionProfiler() noexcept {
    auto wait = [](const CompNodeEventPtr& ev) {
        if (ev)
            ev->host_wait();
    };
    for (auto&& i : m_kern_event) {
        wait(i.second.start);
        wait(i.second.end);
    }
}

float GraphPartitionProfiler::duration_in_usec() const {
    float device_duration = 0.f;
    for (auto&& kern_ev : m_kern_event) {
        auto&& event = kern_ev.second;
        event.end->host_wait();
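        // elapsed_time_until() reports seconds; scale by 1e6 for microseconds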
        device_duration += 1e6 * event.start->elapsed_time_until(*event.end);
    }
    return device_duration;
}

/*!
 * \brief An operator that requires its input var node to be contiguous
 *
 * It serves as the endpoint of the profiling graphs below, forcing its input
 * var to be computed into a contiguous layout.
 */
// clang-format off
MGB_DEFINE_OPR_CLASS(MarkInputContiguous, SingleCNOperatorNodeBase) //{
    void scn_do_execute() override {};
    void init_output_static_infer_desc() override;
    void add_input_layout_constraint() override {
        input(0)->add_layout_constraint_contiguous();
    }
public:
    MarkInputContiguous(VarNode* input, const OperatorNodeConfig& config);
    static SymbolVar make(SymbolVar input, const OperatorNodeConfig& config = {});
};
// clang-format on

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkInputContiguous);

MarkInputContiguous::MarkInputContiguous(
        VarNode* input, const OperatorNodeConfig& config)
        : Super(input->owner_graph(), config, "mark_contiguous", {input}) {
    add_input({input});
    add_output(None);
}

SymbolVar MarkInputContiguous::make(SymbolVar input, const OperatorNodeConfig& config) {
    return input.insert_single_output_opr<MarkInputContiguous>(input.node(), config);
}

void MarkInputContiguous::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
}
}  // namespace

/* ================== ProfilerImpl =================*/
class ProfilerImpl final : public ProfilerBase {
public:
    ProfilerImpl(int runs = 10) : m_runs{runs} {};
    ~ProfilerImpl() = default;
    ProfilingResult profile(const Problem& problem) const override;

private:
    static constexpr float PROFILE_TIME_OUT = 1e7;
    using ReformatAttribute = ReformatKey::Attribute;
    /*!
     * \brief profile opr-format-agnostic operators (e.g. elemwise, elemwise
     * multi type, typecvt)
     *
     * \param opr pointer to the operator node to be profiled
     * \param base_format the original tensor format of the operator node.
     * \param available_tensor_formats the available tensor formats
     * \return the operator node record
     */
    OperatorNodeRecord profile_operator(
            const OperatorNodeBase* opr, TensorFormats base_format,
            const SmallVector<TensorFormats>& available_tensor_formats,
            ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) const;
    float profile_operator(
            const OperatorNodeBase* opr, TensorFormats base_format,
            TensorFormats tensor_format,
            ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) const;
    /*!
     * \brief profile opr-format-aware operators (e.g. conv, deconv,
     * conv_bias)
     *
     * \param opr pointer to the operator node to be profiled
     * \param base_config the tensor formats configuration of base opr format
     * \param available_configs all the available configurations
     * \return the operator node record
     */
    OperatorNodeRecord profile_operator(
            const OperatorNodeBase* opr,
            const OprTensorFormatsConfiguration& base_config,
            const SmallVector<OprTensorFormatsConfiguration>& available_configs,
            ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) const;
    float profile_operator(
            const OperatorNodeBase* opr,
            const OprTensorFormatsConfiguration& base_config,
            const OprTensorFormatsConfiguration& config,
            ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) const;
    /*!
     * \brief profile layout transform of the var node
     *
     * \param var pointer to the var node to be profiled
     * \param base_format the original tensor format in which the var node is
     * stored
     * \param available_tensor_formats the available tensor formats
     * \param extra_attribute the extra attributes (options) of the problem
     * \return the var node record
     */
    VarNodeRecord profile_var_node(
            const VarNode* var, TensorFormats base_format,
            const SmallVector<TensorFormats>& available_tensor_formats,
            ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) const;
    float profile_var_node(
            const VarNode* var, TensorFormats base_format,
            const ReformatKey& key) const;
    int m_runs;  /// number of profiling runs
};

ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
        const OperatorNodeBase* opr, TensorFormats base_format,
        const SmallVector<TensorFormats>& available_tensor_formats,
        ReformatAttribute extra_attribute) const {
    OperatorNodeRecord record;
    record.opr = opr;
    auto& costs = record.costs;
    for (auto&& f : available_tensor_formats) {
        auto opr_format = tensor_formats_to_opr_format(f);
        costs[opr_format] = profile_operator(opr, base_format, f, extra_attribute);
    }
    return record;
}

float ProfilerImpl::profile_operator(
        const OperatorNodeBase* opr, TensorFormats base_format,
        TensorFormats tensor_format, ReformatAttribute extra_attribute) const {
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    graph->options().var_sanity_check_first_run = false;
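    // replace every input var with a device tensor whose shape is aligned to
    // the tensor format being profiled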
    VarNodeArray new_inps(opr->input().size());
    for (size_t i = 0; i < opr->input().size(); ++i) {
        auto&& var = opr->input(i);
        auto&& cn = var->comp_node();
        auto&& dtype = var->dtype();
        auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
        auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
                var, base_format, tensor_format, extra_attribute);
        dval->resize(aligned_tensor_shape);
        auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
        new_inps[i] = aligned_var.node();
    }
    auto new_opr = serialization::copy_opr_shallow(
            *opr, new_inps, opr->config(), {graph.get()});
    if (!m_opr_filter(opr, new_opr))
        return PROFILE_TIME_OUT;
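    // wrap the output with MarkInputContiguous and compile it as the graph
    // endpoint, then time only the shallow-copied operator over m_runs runs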
    auto y = new_opr->output(0);
    auto mark = MarkInputContiguous::make(SymbolVar(y));
    auto func = graph->compile({{mark, {}}});
    auto filter = [new_opr](OperatorNodeBase* opr) { return opr == new_opr; };
    auto profiler =
            std::make_unique<GraphPartitionProfiler>(graph.get(), std::move(filter));
    for (int i = 0; i < m_runs; ++i)
        func->execute();
    return profiler->duration_in_usec();
}

ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
        const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config,
        const SmallVector<OprTensorFormatsConfiguration>& available_configs,
        ReformatAttribute extra_attribute) const {
    OperatorNodeRecord record;
    record.opr = opr;
    auto& costs = record.costs;
    for (auto&& i : available_configs) {
        costs[i.opr_format] = profile_operator(opr, base_config, i, extra_attribute);
    }
    return record;
}

float ProfilerImpl::profile_operator(
        const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config,
        const OprTensorFormatsConfiguration& config,
        ReformatAttribute extra_attribute) const {
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    graph->options().var_sanity_check_first_run = false;
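    // inputs covered by the format configuration are replaced with aligned
    // device tensors; weights and feature maps use different alignment rules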
    VarNodeArray new_inps(opr->input().size());
    size_t i = 0;
    size_t nr_input_tensor =
            std::min(config.input_tensor_formats.size(), opr->input().size());
    for (; i < nr_input_tensor; ++i) {
        auto&& var = opr->input(i);
        auto&& cn = var->comp_node();
        auto&& dtype = var->dtype();
        auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
        TensorShape aligned_shape;
        if (config.input_tensor_types[i] == TensorType::WEIGHT) {
            mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT);
            aligned_shape = ReformatManager::make_aligned_weight_shape(
                    var, base_config.input_tensor_formats[i],
                    config.input_tensor_formats[i], config.output_tensor_formats[0],
                    extra_attribute);
        } else {
            mgb_assert(
                    base_config.input_tensor_types[i] == config.input_tensor_types[i]);
            mgb_assert(base_config.input_tensor_types[i] == TensorType::FEATURE);
            aligned_shape = ReformatManager::make_aligned_tensor_shape(
                    var, base_config.input_tensor_formats[i],
                    config.input_tensor_formats[i], extra_attribute);
        }
        dval->resize(aligned_shape);
        auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
        new_inps[i] = aligned_var.node();
    }
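    // the remaining inputs do not participate in the format configuration;
    // bake their current values into the new graph as immutable tensors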
    for (; i < opr->input().size(); ++i) {
        auto&& var = opr->input(i);
        auto&& cn = var->comp_node();
        auto&& dtype = var->dtype();
        auto hval = std::make_shared<HostTensorND>(cn, dtype);
        hval->resize(var->shape());
        auto cb = [&](DeviceTensorND& d) { hval->copy_from(d).sync(); };
        {
            auto cg = var->owner_graph();
            cg->compile({{var, cb}})->execute();
        }
        auto imm = opr::ImmutableTensor::make(*graph, *hval);
        new_inps[i] = imm.node();
    }
    VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr);
#if 0
    static const ThinHashSet<Typeinfo*> multi_algo_oprs = {
            opr::Convolution::typeinfo(),
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionBackwardData::typeinfo(),
            opr::PoolingForward::typeinfo(),
    };
    if (multi_algo_oprs.count(opr->dyn_typeinfo()) &&
        !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr()))
        return PROFILE_TIME_OUT;
#endif
    if (!m_opr_filter(opr, y->owner_opr()))
        return PROFILE_TIME_OUT;
    auto mark = MarkInputContiguous::make(SymbolVar(y));
    auto func = graph->compile({{mark, {}}});
    auto new_opr = y->owner_opr();
    auto filter = [&new_opr](OperatorNodeBase* opr) { return opr == new_opr; };
    auto profiler =
            std::make_unique<GraphPartitionProfiler>(graph.get(), std::move(filter));
    for (int i = 0; i < m_runs; ++i)
        func->execute();
    return profiler->duration_in_usec();
}

ProfilerImpl::VarNodeRecord ProfilerImpl::profile_var_node(
        const VarNode* var, TensorFormats base_format,
        const SmallVector<TensorFormats>& available_tensor_formats,
        ReformatAttribute attribute) const {
    VarNodeRecord record;
    record.var = var;
    auto& costs = record.costs;
    for (auto&& i : available_tensor_formats) {
        for (auto&& o : available_tensor_formats) {
            if (i == o)
                continue;
            ReformatKey key{
                    i, o, attribute, var->dtype().enumv(), var->dtype().enumv()};
            costs[{i, o}] = profile_var_node(var, base_format, key);
        }
    }
    return record;
}

float ProfilerImpl::profile_var_node(
        const VarNode* var, TensorFormats base_format, const ReformatKey& key) const {
    auto&& cn = var->comp_node();
    auto&& dtype = var->dtype();
    auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
    auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
            var, base_format, key.input_format, key.attribute);
    dval->resize(aligned_tensor_shape);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    graph->options().var_sanity_check_first_run = false;
    auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
    auto builder = ReformatManager::instance().auto_aligned_reformat_featrue(
            var, base_format, key);
    auto y = builder({aligned_var.node()});

    if (!m_var_node_filter(var, aligned_tensor_shape, y->shape(), key))
        return PROFILE_TIME_OUT;
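    // collect every operator created by the reformat builder, so the
    // profiler times the whole reformat expression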
    ThinHashSet<OperatorNodeBase*> set;
    DepOprIter iter([&set](OperatorNodeBase* opr) { set.insert(opr); });
    iter.add(y->owner_opr());
    iter.set_visited(aligned_var.node()->owner_opr());
    auto mark = MarkInputContiguous::make(SymbolVar(y));
    auto func = graph->compile({{mark, {}}});
    auto filter = [&set](OperatorNodeBase* opr) { return set.count(opr) > 0; };
    auto profiler =
            std::make_unique<GraphPartitionProfiler>(graph.get(), std::move(filter));
    for (int i = 0; i < m_runs; ++i)
        func->execute();
    return profiler->duration_in_usec();
}

ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) const {
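    // propagate constness first, so operators and vars that depend only on
    // immutable tensors and params are excluded from profiling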
    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
    {
        auto cb = [&cvprop](OperatorNodeBase* opr) { cvprop.add_opr(opr); };
        DepOprIter iter{cb};
        for (auto&& o : problem.graph_partition().output()) {
            iter.add(o->owner_opr());
        }
    }

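    // number of leading inputs whose tensor format matters for each
    // format-aware operator type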
    static const ThinHashMap<Typeinfo*, size_t> format_aware_input_tensors = {
#define cb(_Opr, _arity) {_Opr::typeinfo(), _arity}
            cb(Convolution, 2),
            cb(ConvBiasForward, 4),
            cb(ConvolutionBackwardData, 2),
            cb(PoolingForward, 1),
            cb(WarpPerspective, 1),
            cb(Resize, 1),
#undef cb
    };
    static const ThinHashSet<Typeinfo*> skip_opr_types = {
            TypeCvt::typeinfo(), Elemwise::typeinfo(), ElemwiseMultiType::typeinfo()};
    ThinHashSet<VarNode*> vars;
    ThinHashSet<OperatorNodeBase*> oprs;
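    // an opr is skippable if it is format-agnostic and all of its inputs
    // come from the partition boundary or from other skippable oprs;
    // skippable oprs are later profiled in the base format only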
    ThinHashSet<OperatorNodeBase*> skip_oprs;
    for (auto&& opr : problem.graph_partition().all_oprs()) {
        if (cvprop.is_const(opr))
            continue;
        bool skip = true;
        for (auto&& i : opr->input()) {
            skip &= problem.graph_partition().input().count(i) > 0 ||
                    skip_oprs.count(i->owner_opr()) > 0;
        }
        skip &= skip_opr_types.count(opr->dyn_typeinfo());
        if (skip)
            skip_oprs.insert(opr);
        oprs.insert(opr);
        auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
        if (find == format_aware_input_tensors.end()) {
            for (auto&& i : opr->input()) {
                if (!cvprop.is_const(i)) {
                    vars.insert(i);
                }
            }
        } else {
            size_t nr_input_tensor = std::min(find->second, opr->input().size());
            for (size_t i = 0; i < nr_input_tensor; ++i) {
                if (!cvprop.is_const(opr->input(i))) {
                    vars.insert(opr->input(i));
                }
            }
        }
        for (auto&& ov : opr->usable_output()) {
            vars.insert(ov);
        }
    }

    auto base_format = problem.base_format();
    auto&& available_tensor_formats = problem.available_tensor_formats();
    auto&& reformat_attribute = problem.attribute().reformat_attribute;

    ProfilingResult profiling_result;
    auto& opr_record = profiling_result.opr_record;
    auto& var_record = profiling_result.var_record;
    for (auto&& var : vars) {
        var_record[var] = profile_var_node(
                var, base_format, available_tensor_formats, reformat_attribute);
    }
    for (auto&& opr : oprs) {
        auto&& opr_configs = problem.opr_configs();
        auto find = opr_configs.find(opr->dyn_typeinfo());
        if (find == opr_configs.end()) {
            if (skip_oprs.count(opr) > 0) {
                SmallVector<TensorFormats> tensor_formats = {base_format};
                opr_record[opr] = profile_operator(
                        opr, base_format, tensor_formats, reformat_attribute);
            } else {
                opr_record[opr] = profile_operator(
                        opr, base_format, available_tensor_formats, reformat_attribute);
            }
        } else {
            auto&& dispatchers = find->second;
            SmallVector<OprTensorFormatsConfiguration> configs;
            for (const auto& item : dispatchers) {
                auto config = (*item.second)(opr);
                if (config.valid()) {
                    configs.emplace_back(config.val());
                }
            }
            auto base_config = problem.base_config(opr);
            opr_record[opr] =
                    profile_operator(opr, base_config, configs, reformat_attribute);
        }
    }
    for (auto&& rpair : opr_record) {
        mgb_log_debug("%s", rpair.second.to_string().c_str());
    }
    for (auto&& rpair : var_record) {
        mgb_log_debug("%s", rpair.second.to_string().c_str());
    }
    return profiling_result;
}

/* ================== ProfilerBase =================*/
ProfilerBase::ProfilerBase(float opr_threshold, float var_node_threshold)
        : m_opr_threshold{opr_threshold}, m_var_node_threshold{var_node_threshold} {
    m_opr_filter = [this](const OperatorNodeBase* opr, OperatorNodeBase* new_opr) {
        /// \note: for performance reasons, we skip the naive NCHW kernels of
        /// quantized conv bias on the CUDA platform. TODO: remove this later
        if (auto conv = try_cast_as_op<opr::ConvBiasForward>(new_opr)) {
            if (conv->output(0)->comp_node().device_type() ==
                        CompNode::DeviceType::CUDA &&
                conv->input(0)->dtype().category() == DTypeCategory::QUANTIZED &&
                conv->param().format == OprFormat::NCHW) {
                return false;
            }
        }
        float comp1 =
                m_opr_footprint.get_computation(const_cast<OperatorNodeBase*>(opr));
        float comp2 = m_opr_footprint.get_computation(new_opr);
        if (comp2 > m_opr_threshold * comp1)
            return false;
        return true;
    };
    m_var_node_filter = [this](const VarNode* var, TensorShape from, TensorShape to,
                               ReformatKey key) {
        /// \note: due to the alignment requirements of low-bit tensors, we
        /// skip some layout transforms for them. The skipped layout
        /// transforms have no corresponding dnn kernels and cannot be
        /// implemented by tensor manip operators (like reshape, dimshuffle,
        /// subtensor, etc.).
        if (var->dtype().enumv() == DTypeEnum::QuantizedS4 ||
            var->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
            if (key.input_format == TensorFormats::NCHW &&
                key.output_format != TensorFormats::NHWC &&
                key.output_format != TensorFormats::NCHWc64) {
                return false;
            }
            if (key.output_format == TensorFormats::NCHW &&
                key.input_format != TensorFormats::NHWC &&
                key.input_format != TensorFormats::NCHWc64) {
                return false;
            }
        }
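        /// a rough bandwidth heuristic: the original var accounts for one
        /// read and one write (hence the factor 2.f), while the reformat
        /// reads `from` and writes `to`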
        TensorLayout orig_ly = {var->shape(), var->dtype()},
                     from_ly = {from, var->dtype()}, to_ly = {to, var->dtype()};
        float orig_memory = orig_ly.span().dist_byte() * 2.f;
        float reformat_memory = from_ly.span().dist_byte() + to_ly.span().dist_byte();
        if (reformat_memory > orig_memory * m_var_node_threshold)
            return false;
        return true;
    };
}

std::string ProfilerBase::OperatorNodeRecord::to_string() const {
    auto str = ssprintf(
            "\nopr type: %s\nopr name: %s\ninputs:\n", opr->dyn_typeinfo()->name,
            opr->cname());
    for (auto&& i : opr->input()) {
        str += ssprintf(
                "\tvar: %s\n\tshape: %s\n", i->cname(), i->shape().to_string().c_str());
    }
    str += ssprintf(
            "outputs:\n\tvar: %s\n\tshape: %s\ncosts:\n", opr->output(0)->cname(),
            opr->output(0)->shape().to_string().c_str());
    for (auto&& cpair : costs) {
        str += ssprintf(
                "\tformat: %s; cost:%f", opr_format_to_string(cpair.first),
                cpair.second);
    }
    return str;
}

std::string ProfilerBase::VarNodeRecord::to_string() const {
    auto str = ssprintf("\nvar: %s\ncosts:", var->cname());
    for (auto&& cpair : costs) {
        auto&& formats = cpair.first;
        str += ssprintf(
                "\n\tformat: (i:%s;o:%s); cost:%f",
                tensor_formats_to_named_tensor_shape(formats.first).to_string().c_str(),
                tensor_formats_to_named_tensor_shape(formats.second)
                        .to_string()
                        .c_str(),
                cpair.second);
    }
    return str;
}

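/*!
 * A minimal usage sketch (assuming \c problem is a layout transform Problem
 * built from a graph partition):
 * \code
 * auto profiler = ProfilerBase::make_profiler();
 * auto rst = profiler->profile(problem);
 * const auto& opr_record = rst.opr_record;
 * const auto& var_record = rst.var_record;
 * \endcode
 */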
std::unique_ptr<ProfilerBase> ProfilerBase::make_profiler() {
    return std::make_unique<ProfilerImpl>();
}

// vim: syntax=cpp.doxygen