/**
 * \file src/gopt/impl/framework.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/timer.h"

#if MGB_JIT
#include "megbrain/jit/fusion_pass.h"
#endif

#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/opr_replace.h"
#endif

#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/gopt/layout_transform_pass.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/solver.h"

using namespace mgb;
using namespace gopt;

/* ================ SubGraph ================ */

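// If any input of `opr` has been replaced by an earlier rewrite, shallow-copy
// the operator onto the new inputs and record the mapping from each old output
// var to the corresponding new one, so downstream operators pick up the
// replacement; otherwise the operator is returned unchanged.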
OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(OperatorNodeBase* opr) {
    auto&& new_inp = m_opr_new_inp_cache;
    new_inp.clear();
    new_inp.reserve(opr->input().size());
    bool has_replaced_inp = false;

    for (auto i : opr->input()) {
        auto new_var = get_var(i);
        if (new_var != i) {
            has_replaced_inp = true;
            new_inp.push_back(new_var);
        } else {
            new_inp.push_back(i);
        }
    }

    if (has_replaced_inp) {
        auto new_opr = serialization::copy_opr_shallow(*opr, new_inp, opr->config());
        auto &&out0 = opr->output(), &&out1 = new_opr->output();
        size_t i = 0;
        auto err_msg = [opr, new_opr] {
            return ssprintf(
                    "bad opr copy: src=%s{%s} dst=%s{%s}", opr->cname(),
                    opr->dyn_typeinfo()->name, new_opr->cname(),
                    new_opr->dyn_typeinfo()->name);
        };
        MGB_MARK_USED_VAR(err_msg);
        // opr output size mismatch may be caused by:
        //     0) inplace arith optimization (e.g. PowC needs an extra workspace)
        //     1) other post-insert optimization (e.g. const folding)
        // we can't handle only usable_output here, since some output var with
        // volatile flag could be the graph's endpoint (e.g. RemoteSend)
        for (; i < std::min(out0.size(), out1.size()); ++i) {
            bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                 v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT);
            mgb_assert(v0 == v1, "%s", err_msg().c_str());

            auto&& ins = m_varmap.insert({out0[i], {true, nullptr}});
            mgb_assert(
                    ins.second || ins.first->second.first,
                    "opr output already replaced");
            // handle repeated call on the same opr
            ins.first->second.second = out1[i];
            on_var_replaced(out0[i], out1[i], nullptr);
        }
        for (; i < out0.size(); ++i) {
            mgb_assert(
                    out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), "%s",
                    err_msg().c_str());
        }
        for (; i < out1.size(); ++i) {
            mgb_assert(
                    out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), "%s",
                    err_msg().c_str());
        }
        return new_opr;
    }
    return opr;
}

void SubGraph::Rewriter::replace_var(VarNode* src, VarNode* dst, const char* msg) {
    if (src == dst)
        return;

    // Optimizers should not create a loop in the variable replace map.
    mgb_throw_if(
            get_var_internal(dst).second == src, InternalError,
            "dst %s maps back to src %s in SubGraph::Rewriter::replace_var",
            dst->cname(), src->cname());

    auto&& ins = m_varmap.insert({src, {false, dst}});
    if (!ins.second) {
        auto&& old_rep = ins.first->second;
        mgb_assert(
                old_rep.first || old_rep.second == dst, "can not replace a var twice");
        old_rep.first = false;
        old_rep.second = dst;
    }
    on_var_replaced(src, dst, msg);
}

void SubGraph::Rewriter::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
    if (auto state = m_owner_graph->owner_opt_state()) {
        state->on_var_replaced(src, dst, msg);
    }
}

void SubGraph::Rewriter::apply_inplace() const {
    m_owner_graph->m_endpoint_oprs.clear();
    m_owner_graph->m_endpoint_vars_set.clear();
    for (auto&& var : m_owner_graph->m_endpoint_vars) {
        var = get_var(var.node());
        m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr());
        m_owner_graph->m_endpoint_vars_set.insert(var.node());
    }
}

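// Resolve the final replacement target of `var` by following m_varmap; entries
// visited on the way are updated to point at the final target (path
// compression) so that later lookups stay cheap. The bool of the returned pair
// is true only if every link of the chain was created by auto_replace_outputs().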
std::pair<bool, VarNode*> SubGraph::Rewriter::get_var_internal(VarNode* var) {
    // The implementation is (manually) unrolled once, background:
    // git-core/brain-sdk/MegBrain/merge_requests/486#note_76971
    auto it = m_varmap.find(var);
    if (it == m_varmap.end()) {
        return {true, var};
    }
    mgb_assert(it->second.second != var, "loop detected in m_varmap");
    auto it_next = m_varmap.find(it->second.second);
    if (it_next == m_varmap.end()) {
        return it->second;
    }
    mgb_assert(
            it_next->second.second != it->second.second, "loop detected in m_varmap");
    auto next = get_var_internal(it_next->second.second);
    it_next->second = {next.first & it_next->second.first, next.second};
    return it->second = {it_next->second.first & it->second.first, next.second};
}

SubGraph::SubGraph(const SymbolVarArray& endpoint_vars)
        : m_endpoint_vars(endpoint_vars) {
    mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty");
    m_comp_graph = endpoint_vars[0].node()->owner_graph();
    for (auto i : endpoint_vars) {
        m_endpoint_oprs.insert(i.node()->owner_opr());
        m_endpoint_vars_set.insert(i.node());
        mgb_assert(
                m_comp_graph == i.node()->owner_graph(),
                "endpoints belong to different computing graphs");
    }
}

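// Visit, in topological order, every operator that the endpoints depend on.
// When an OptState is attached, the properties of the operator currently being
// visited (root source opr, priority, stream propagation type) are recorded
// around the callback, so operators inserted from within the callback inherit
// them via the OprInserted event handler registered by OptState.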
void SubGraph::iter(const Callback& cb, std::shared_ptr<ExtraDep> extra_dep) const {
    Callback on_opr;

    if (m_owner_opt_state) {
        on_opr = [state = m_owner_opt_state, &cb](OperatorNodeBase* opr) {
            state->m_opr_property_flag = OprPropertyFlag::ALL;
            state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr);
            state->m_cur_iter_opr_priority = opr->node_prop().attribute().priority;
            state->m_cur_iter_opr_stream_prop_type =
                    state->m_comp_node_opt.stream_prop_type(opr->output(0));
            mgb_assert(state->m_oprs_inserted.empty());
            cb(opr);
            state->m_opr_property_flag = OprPropertyFlag::NONE;
            state->m_cur_iter_src_opr = nullptr;
            state->m_oprs_inserted.clear();
        };
    } else {
        on_opr = cb;
    }

    cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)};
    for (auto i : m_endpoint_oprs)
        dep_iter.add(i);
}

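// Count how many operators take a device-value dependency on each non-volatile
// var of the sub graph; endpoint vars receive one extra count so that direct
// graph outputs are also treated as readers.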
ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {
    ThinHashMap<VarNode*, size_t> ret;
    auto cb = [&](OperatorNodeBase* opr) {
        for (auto&& i : opr->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) {
                ++ret.at(i.first);
            }
        }
        for (auto i : opr->output()) {
            if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto ins = ret.insert({i, 0});
                mgb_assert(ins.second);
            }
        }
    };
    iter(cb);
    for (auto i : m_endpoint_vars_set) {
        auto iter = ret.find(i);
        if (iter == ret.end()) {
            mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
            ret[i] = 1;
        } else {
            ++ret.at(i);
        }
    }
    return ret;
}

/* ================ UniqReaderCheck ================ */

UniqReaderCheck::UniqReaderCheck(const SubGraph& graph)
        : m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} {}

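// When an operator has been auto-replaced, carry the reader counts of its old
// output vars over to the corresponding outputs of the replacement operator.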
void UniqReaderCheck::update_on_opr_auto_replace(
        OperatorNodeBase* opr, OperatorNodeBase* repl_opr) {
    auto non_volatile_size = [](const VarNodeArray& vars) -> size_t {
        size_t size = 0;
        for (size_t i = 0; i < vars.size(); ++i) {
            if (!vars[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                size++;
            }
        }
        return size;
    };
    if (opr != repl_opr) {
        auto &&o0 = opr->output(), &&o1 = repl_opr->output();
        mgb_assert(non_volatile_size(o0) == non_volatile_size(o1));
        for (size_t i = 0; i < o0.size(); ++i) {
            auto iter = m_var2nr_val_dep.find(o0[i]);
            if (iter != m_var2nr_val_dep.end()) {
                auto n = iter->second;
                m_var2nr_val_dep[o1[i]] = n;
            }
        }
    }
}

/* ================ OptState ================ */

OptState::OptState(const GraphOptimizer* owner_optimizer, const SubGraph& graph)
        : m_owner_optimizer{owner_optimizer},
          m_var_replace_map{const_cast<ThinHashMap<VarNode*, VarNode*>*>(
                  &GraphOptimizer::var_replace_map(*graph.comp_graph()))},
          m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()},
          m_graph{graph} {
    mgb_assert(!m_graph.m_owner_opt_state);
    m_var_replace_map->clear();
    m_graph.m_owner_opt_state = this;
    m_oprs_inserted.clear();

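    // Handler for cg::event::OprInserted: an operator created while a pass is
    // visiting some opr inherits that opr's source opr (if not already set)
    // and priority, and its outputs are registered with the same comp node
    // stream propagation type.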
    auto on_opr_insert = [this](const cg::event::OprInserted& ev) {
        auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY;
        if (need_src_opr)
            mgb_assert(
                    m_cur_iter_src_opr,
                    "opr %s{%s} created outside from "
                    "SubGraph::iter",
                    ev.opr->cname(), ev.opr->dyn_typeinfo()->name);
        if (ev.exc || ev.is_dedup)
            return;

        auto&& new_attr = ev.opr->node_prop().attribute();
        auto&& ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE});
        mgb_assert(ins.second);

        if (need_src_opr && !new_attr.src_opr) {
            auto src_opr = m_cur_iter_src_opr;
            if (ev.opr != src_opr)
                new_attr.src_opr = src_opr;
            ins.first->second |= OprPropertyFlag::SOURCE_OPR;
        }
        if (need_priority) {
            new_attr.priority = m_cur_iter_opr_priority;
            if (!ev.opr->update_priority()) {
                ins.first->second |= OprPropertyFlag::PRIORITY;
            }
        }

        auto csp = m_cur_iter_opr_stream_prop_type;
        if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) {
            for (auto i : ev.opr->output())
                m_comp_node_opt.register_stream_var(i, csp);
        }
    };
    m_on_opr_insert_handler =
            graph.comp_graph()->event().register_receiver<cg::event::OprInserted>(
                    on_opr_insert);
}

void OptState::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
    if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
        // this can only happen in auto_replace_outputs()
        mgb_assert(
                dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) &&
                src->owner_opr()->dyn_typeinfo() == dst->owner_opr()->dyn_typeinfo());
        mgb_assert(!msg);
        return;
    }

    //! check_property
    {
        auto iter = m_oprs_inserted.find(dst->owner_opr());
        if (iter != m_oprs_inserted.end()) {
            auto &&src_attr = src->owner_opr()->node_prop().attribute(),
                 &&dst_attr = dst->owner_opr()->node_prop().attribute();
            auto opr_info = [&](OperatorNodeBase* opr) {
                return opr ? opr->name() + "(" + std::to_string(opr->id()) + ")"
                           : "NULL";
            };
            auto err_msg = [&] {
                std::string ret = "Please contact Engine group:\n";
                ret += "src opr: ";
                ret += opr_info(src->owner_opr());
                ret += ", dst opr: ";
                ret += opr_info(dst->owner_opr());
                return ret;
            };
            MGB_MARK_USED_VAR(err_msg);
            if (iter->second & OprPropertyFlag::SOURCE_OPR) {
                auto &&src_rt = get_opr_root_source_opr(src->owner_opr()),
                     &&dst_rt = get_opr_root_source_opr(dst->owner_opr());
                mgb_assert(
                        dst_rt == src_rt,
                        "%s\nsrc source_opr: %s, dst source_opr: %s\n",
                        err_msg().c_str(), opr_info(src_rt).c_str(),
                        opr_info(dst_rt).c_str());
            }
            if (iter->second & OprPropertyFlag::PRIORITY) {
                mgb_assert(
                        src_attr.priority == dst_attr.priority,
                        "%s\nsrc priority: %d, dst priority %d\n", err_msg().c_str(),
                        src_attr.priority, dst_attr.priority);
            }
        }
    }

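    // consistency checks controlled by m_var_replace_check_flag: the
    // replacement var must preserve the static inferability, dtype and shape
    // of the var it replaces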
    {
        bool suc = true;
        SmallVector<std::string> fail_chks;
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_INFER_TYPE) {
            auto&& mgr = src->owner_graph()->static_infer_manager();
            auto it0 = mgr.get_infer_type(src), it1 = mgr.get_infer_type(dst);
            using cg::static_infer::InferType;
            // only check whether inferable
            auto norm = [](InferType::Flag f) -> bool {
                return f & (InferType::RT_STATIC | InferType::CONST);
            };
            if (!(norm(it0.shape) == norm(it1.shape) &&
                  norm(it0.value) <= norm(it1.value))) {
                suc = false;
                fail_chks.push_back("infer-type");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_DTYPE) {
            if (src->dtype() != dst->dtype()) {
                suc = false;
                fail_chks.push_back("dtype");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_SHAPE) {
            if (!(src->shape().eq_shape(dst->shape()))) {
                suc = false;
                fail_chks.push_back("shape");
            }
        }
        if (!suc) {
            std::string fail_msg = "{";
            for (size_t i = 0; i < fail_chks.size(); i++) {
                fail_msg += fail_chks[i];
                if (i < fail_chks.size() - 1) {
                    fail_msg += ",";
                }
            }
            fail_msg += "}";
            mgb_throw_raw(
                    cg::OperatorNodeExcExtraInfo::ExcMaker{src->owner_opr()}
                            .make<InternalError>(ssprintf(
                                    "%s mismatch for replace_var: %s", fail_msg.c_str(),
                                    cg::dump_var_info({src, dst}).c_str())));
        }
    }

    if (src->has_name_set() && !dst->has_name_set()) {
        dst->name(src->name());
    }
    (*m_var_replace_map)[src] = dst;
    // dst should be considered as newly inserted, and previous replace
    // record should be ignored
    m_var_replace_map->erase(dst);

#if MGB_ENABLE_LOGGING
    if (msg && m_owner_optimizer->verbosity()) {
        m_log_msg.append("\n ")
                .append(std::to_string(m_log_nr_item))
                .append(": ")
                .append(src->owner_opr()->cname())
                .append(" => ")
                .append(dst->owner_opr()->cname())
                .append(" (")
                .append(msg)
                .append(")");
    }
    ++m_log_nr_item;
#endif
}

size_t OptState::flush_log(const char* title) {
    if (m_owner_optimizer->verbosity() >= 2) {
        if (m_log_msg.empty()) {
            m_log_msg = mgb_cstr_log(" no var replacement logged");
        }
        mgb_log("%s%s", title, m_log_msg.c_str());
        m_log_msg.clear();
    }
    auto ret = m_log_nr_item;
    m_log_nr_item = 0;
    return ret;
}

void OptState::call_with_opr(
        OperatorNodeBase* opr, thin_function<void(void)> func,
        OprPropertyFlag opr_property_flag) {
    auto src_opr = cg::get_opr_root_source_opr(opr);
    auto opr_priority = opr->node_prop().attribute().priority;
    auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0));
    ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted;

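    // Temporarily make `opr` the "currently visited" operator so that oprs
    // created by func() inherit the requested properties; the previous
    // iteration state is restored even if func() throws.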
    auto swap_properties =
            [&, need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] {
                if (need_src_opr) {
                    std::swap(m_cur_iter_src_opr, src_opr);
                }
                if (need_priority) {
                    std::swap(m_cur_iter_opr_priority, opr_priority);
                }
                std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type);
                std::swap(m_opr_property_flag, opr_property_flag);
                std::swap(m_oprs_inserted, oprs_inserted);
            };
    MGB_TRY {
        swap_properties();
        func();
    }
    MGB_FINALLY({ swap_properties(); });
}

/* ================ RecursiveSubGraphRewriteHelper ================ */
RecursiveSubGraphRewriteHelper::~RecursiveSubGraphRewriteHelper() noexcept = default;

RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState& state)
        : m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} {}

void RecursiveSubGraphRewriteHelper::apply() {
    using namespace std::placeholders;
    m_opt_state.graph().iter(
            std::bind(&RecursiveSubGraphRewriteHelper::on_opr, this, _1));
    m_rewriter.apply_inplace();
}

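// Rewrite a single operator: replace its outputs if its inputs changed, then
// repeatedly apply process_opr() to the (single-output) result via an explicit
// stack, so that chained rewrites do not recurse on the C++ call stack; the
// messages of all applied transformations are joined and logged together with
// the final var replacement.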
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase* opr) {
    auto on_new_opr = [this](OperatorNodeBase* opr) {
        auto repl_opr = m_rewriter.auto_replace_outputs(opr);
        return on_new_opr_check_should_process(opr, repl_opr);
    };

    if (!on_new_opr(opr))
        return;

    auto orig_out = get_opr_single_output_var(opr);
    if (!orig_out)
        return;

    mgb_assert(m_opr_stack.empty());
    m_opr_stack.push_back({orig_out, m_rewriter.get_var(orig_out)->owner_opr()});

    bool first = true;
    while (!m_opr_stack.empty()) {
        auto cur_frame = m_opr_stack.back();
        m_opr_stack.pop_back();
        auto cur_opr = cur_frame.opr;
        bool should_process;
        if (first) {
            should_process = true;
            first = false;
        } else {
            should_process = on_new_opr(cur_opr);
        }
        auto cur_out = get_opr_single_output_var(cur_opr);
        mgb_assert(cur_out);
        cur_out = m_rewriter.get_var(cur_out);

        if (should_process) {
            auto trans = process_opr(cur_out);
            if (trans.valid()) {
                m_opr_stack.push_back({cur_frame.orig_var, trans->result->owner_opr()});
                for (auto i : reverse_adaptor(trans->internal)) {
                    if (i)
                        m_opr_stack.push_back({i, i->owner_opr()});
                }
                if (trans->msg) {
                    if (!m_log_msg.empty())
                        m_log_msg.push_back(';');
                    m_log_msg.append(trans->msg);
                }
                continue;
            }
        }

        auto src = cur_frame.orig_var;
        if (m_rewriter.get_var(src) != cur_out) {
            const char* msg = nullptr;
            if (m_opr_stack.empty()) {
                msg = m_log_msg.c_str();
            }
            m_rewriter.replace_var(src, cur_out, msg);
            after_replace_var(src, cur_out);
            if (m_opr_stack.empty()) {
                m_log_msg.clear();
                break;
            }
        }
    }
}

/* ================ GraphOptimizer ================ */

GraphOptimizer::~GraphOptimizer() noexcept = default;

class GraphOptimizer::VarReplaceMapStorage : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

public:
    ThinHashMap<VarNode*, VarNode*> map;
};
MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage);

GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr<Pass> pass) {
    mgb_assert(!pass->m_owner_optimizer);
    pass->m_owner_optimizer = this;
    m_passes.emplace_back(std::move(pass));
    return *this;
}

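// Run all registered passes sequentially on the given sub graph;
// graph_opt_level is temporarily forced to 1 while each pass runs and is
// restored afterwards, even if a pass throws.
//
// Typical usage (a sketch only, argument defaults as declared in framework.h):
//     gopt::GraphOptimizer optimizer;
//     optimizer.add_preset_passes();
//     SymbolVar y_opt = optimizer.apply({{y}}).endpoint_vars()[0];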
SubGraph GraphOptimizer::apply(const SubGraph& graph) const {
    RealTimer timer;
    OptState state{this, graph};

    size_t tot_nr_replace = 0;

    // first update output var shapes of all oprs
    state.graph().iter(cg::update_output_var_shapes);

    auto&& opt = graph.comp_graph()->options();
    auto orig_setting = opt.graph_opt_level;
    Pass* cur_pass = nullptr;
    MGB_MARK_USED_VAR(cur_pass);
    MGB_TRY {
        for (auto&& i : m_passes) {
            state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL);
            cur_pass = i.get();
            opt.graph_opt_level = 1;
            i->apply(state);
            tot_nr_replace += state.flush_log(
                    mgb_ssprintf_log("apply optimization pass %s:", i->name()).c_str());
        }
    }
    MGB_CATCH(std::exception & exc, {
        mgb_log_error(
                "error while applying optimization pass %s: %s", cur_pass->name(),
                exc.what());
        opt.graph_opt_level = orig_setting;
        throw;
    })
    MGB_FINALLY(opt.graph_opt_level = orig_setting);
    if (verbosity() >= 1) {
        mgb_log_debug(
                "graph optimization: applied %zu passes, "
                "total %zu var(s) replaced; time=%.2fms",
                m_passes.size(), tot_nr_replace, timer.get_msecs());
    }
    return state.graph();
}

const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray& vars) const {
    if (m_passes.empty()) {
        // this check is necessary, since OptState would clear
        // var_replace_map()
        return *this;
    }

    auto g = apply({{vars.begin(), vars.end()}});
    for (size_t i = 0; i < vars.size(); ++i) {
        vars[i] = g.endpoint_vars()[i].node();
    }
    return *this;
}

GraphOptimizer& GraphOptimizer::add_preset_passes(
        bool after_grad, const OptimizeForInferenceOptions* inference_opt,
        const ComputingGraph::Options* comp_graph_opt) {
    auto cv_type =
            inference_opt ? ConstVarType::IMMUTABLE_AND_PARAM : ConstVarType::IMMUTABLE;
    if (inference_opt) {
        add_pass<ConvertBatchNormToElemwisePass>();
    }
    if (!after_grad || inference_opt) {
        add_pass<CondExecConstPredicateFolding>();
    }
    if (after_grad || inference_opt) {
        add_pass<RemoveNonComputingOprPass>();
    }
    add_pass<DelayBroadcastPass>();
    add_pass<ExpandFusedArithPass>();
    add_pass<NormalizeArithChainPass>();
    if (inference_opt) {
        add_pass<ParamRedistributePass>();
        add_pass<ParamFusePass>();
    }
    add_pass<ArithMulDistributePass>();
    add_pass<ReorderArithChainPass>(cv_type);

    add_pass<ArithFusePass>();
    // reorder again because shapes of fused oprs might change
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<FinalArithTransformPass>();
    add_pass<RemoveRedundantTypeCvtPass>();
    add_pass<RemoveRedundantCopyPass>();

#if MGB_JIT
    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
    int jit_opt_level = 0;
    JITConfig jit_config;

    // for more detail on what is happening here, see comments on the
    // constructor of class JITFusionPass in fusion_pass.h
    if (comp_graph_opt) {
        jit_opt_level = comp_graph_opt->graph_opt.jit;
        if (comp_graph_opt->graph_opt_level >= 3) {
            jit_opt_level = std::max(jit_opt_level, 1);
        }
        jit_config = comp_graph_opt->graph_opt.jit_config;
    }
    bool need_jit = (jit_opt_level > 0) || jit_config.enabled();

    if (need_jit && after_grad) {
        add_pass<gopt::RecompTypeCvtPass>();
    }
#endif

    // combine astype and reduce.
    // Note: apply this pass before JITFusion, so the TypeCvt which
    // is read by both Reduce and Elemwise could be fused correctly.
    add_pass<CombineAstypeAndReducePass>();

#if MGB_JIT
    if (need_jit) {
        add_pass<gopt::JITFusionPass>(after_grad, jit_opt_level, jit_config);
    }
#endif

    if (inference_opt) {
        add_pass<ParamFusePass>();
        add_passes_for_optimize_options(*inference_opt);
    }

    if (inference_opt) {
        // merge params to reduce loading time and graph overhead
        add_pass<ParamMergePass>();
        add_pass<FuseDeconvCvtPass>();
    }

    if (inference_opt) {
        // remove shape hint after inference optimization
        add_pass<RemoveShapeHintPass>();
    }
    return *this;
}

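// Return the src-var -> dst-var replacement map filled in by the most recent
// GraphOptimizer::apply() on this graph; it is stored as user data attached to
// the computing graph options, so it outlives the optimizer itself.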
const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map(
        ComputingGraph& graph) {
    auto storage =
            graph.options().user_data.get_user_data_or_create<VarReplaceMapStorage>();
    return storage->map;
}

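// Follow the replacement chain recorded in var_replace_map() until reaching a
// var that has not been replaced.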
VarNode* GraphOptimizer::var_replace_lookup(VarNode* var) {
    auto&& map = var_replace_map(*(var->owner_graph()));
    for (;;) {
        auto iter = map.find(var);
        if (iter == map.end())
            return var;
        var = iter->second;
    }
}

const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        const cg::GraphCommonOptimizeOptions& options) {
    return add_passes_for_optimize_options(
            const_cast<cg::GraphCommonOptimizeOptions&>(options));
}

const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        cg::GraphCommonOptimizeOptions& options, bool reset) {
    bool need_param_fuse = false;

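// cb(_option, _passes): if _option has been requested in `options`, run the
// statement block _passes to register the corresponding passes, remember that
// a trailing ParamFusePass is needed, and clear the option again when `reset`
// is true.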
#define cb(_option, _passes)             \
    if (options.has_set_##_option()) {   \
        _passes need_param_fuse = true;  \
        if (reset) {                     \
            options.disable_##_option(); \
        }                                \
    }

    cb(fuse_preprocess, {
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
    });
    cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
    cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });

    cb(nchw4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nhwcd4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(ConvertFormatPass::make_nhwcd4_converter());
    });
    cb(nchw88, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44_dot, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchw44DotPass::make_nchw44_dot_converter());
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw32, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass(EnableTensorCorePass::make_tensorcore_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
#endif
    });
    cb(chwn4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass(EnableCHWN4Pass::make_chwn4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nchw64, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<PaddingChannelPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW64Pass::make_nchw64_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
#endif
    });

    cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); });
    cb(fuse_conv_bias_with_z, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
    });

#undef cb

    if (need_param_fuse) {
        add_pass<ParamFusePass>();
    }
    return *this;
}

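// Register the passes implied by the graph tuning options; currently only
// layout_transform is handled, which runs the global layout transform pass
// together with the related fusion and cleanup passes for the given target.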
const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
        const GraphTuningOptions& options) {
    bool need_param_fuse = false;

#define cb(_options, _passes)           \
    if (options.has_set_##_options()) { \
        _passes need_param_fuse = true; \
    }

    using Target = GraphTuningOptions::Target;
    cb(layout_transform, {
        add_pass<FuseConvBiasNonlinPass>();
        if (options.target == Target::CUDA)
            add_pass<FuseConvBiasZPass>();
        add_pass(LayoutTransformPass::make(options.target));
        add_pass<ShuffleShuffleRemovePass>();
        if (options.target == Target::CUDA) {
            add_pass(FuseNCHW4Int8Preprocess::make());
            add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
            add_pass<FoldingConvBiasDimshufflePass>();
            add_pass<FoldingConvBiasTypecvtPass>();
#endif
        }
    });
#undef cb

    if (need_param_fuse) {
        add_pass<ParamFusePass>();
    }
    return *this;
}

/* ================ ConstVarPropogateBase ================ */

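// Decide (and cache in m_oprinfo) whether `opr` can be treated as a constant:
// either it is itself a const var source, or all of its inputs are const and
// the operator has no side effects (neither FORCE_UPDATE_INPUT_VAR nor
// IMPURE_FUNC); max_size records the largest constant var encountered on the
// way.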
ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(OperatorNodeBase* opr) {
    using ProfFlag = OperatorNodeBase::NodeProp::Flag;
    auto&& info = m_oprinfo[opr];
    if (info.processed)
        return info.result;
    info.processed = true;

#if MGB_ENABLE_JSON
    (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(false);
#endif

    AddOprResult ret{false, false, false};
    auto make_ret = [&ret, &info]() {
        info.result = ret;
        return ret;
    };

    if (is_const_var(m_const_var_type, opr)) {
        auto sz = var_mem_size(opr->output(0));
        mgb_assert(
                sz || opr->output(0)->contain_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE));
870 871 872 873 874 875 876 877
        info.is_const = true;
        info.max_size = sz;
        return make_ret();
    }

    if (opr->input().empty())
        return make_ret();

    if (opr->node_prop().contain(
                ProfFlag::FORCE_UPDATE_INPUT_VAR | ProfFlag::IMPURE_FUNC)) {
        return make_ret();
    }

    size_t max_input_size = 0;
    ret.all_const_inp = true;
885
    for (auto i : opr->input()) {
886 887 888 889 890 891 892
        auto io = i->owner_opr();
        auto iter = m_oprinfo.find(io);
        if (iter == m_oprinfo.end()) {
            add_opr(io);
            iter = m_oprinfo.find(io);
            mgb_assert(iter != m_oprinfo.end());
        }
893
        auto&& src = iter->second;
894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914
        if (src.is_const) {
            update_max(max_input_size, src.max_size);
            ret.has_const_inp = true;
            if (!is_const_var(m_const_var_type, i->owner_opr())) {
                ret.has_midconst_inp = true;
            }
        } else {
            ret.all_const_inp = false;
        }
    }
    if (ret.all_const_inp) {
#if MGB_ENABLE_JSON
        (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(true);
#endif
        info.max_size = max_input_size;
        info.is_const = true;
    }
    return make_ret();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}