/**
 * \file imperative/src/impl/ops/utility.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include <atomic>
#include <deque>

#include "megbrain/imperative/graph_cache.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/resource_manager.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#if MGB_JIT
#include "megbrain/jit/executor_opr.h"
#endif

#include "../event_pool.h"
#include "../op_trait.h"

namespace mgb::imperative {

MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
OP_TRAIT_REG(GenericPyOp, GenericPyOp).fallback();

namespace {
namespace fastpathcopy {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    return inputs;
}

auto make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
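    // FastpathCopy acts as identity for autograd: the backward subgraph has no
    // exprs and simply returns input #3, which under the usual
    // (inputs, outputs, output-grads) backward-graph layout is the incoming
    // output gradient, forwarded unchanged as the input gradient.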
    Subgraph graph;
    graph.inputs = {1, 2, 3};
    graph.outputs = {3};
    graph.exprs = {};
    return EncodedSubgraph::make(graph);
}

OP_TRAIT_REG(FastpathCopy, FastpathCopy)
        .apply_on_var_node(apply_on_var_node)
        .make_backward_graph(make_backward_graph)
        .fallback();
}  // namespace fastpathcopy
}  // namespace

namespace {
namespace shape_infer {
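// ShapeInfer wraps another OpDef and evaluates only its shape inference.
// Inputs and outputs are 1-D int32 tensors that carry shapes rather than data:
// apply_on_physical_tensor rebuilds TensorLayouts from the input shape vectors,
// runs the wrapped op's shape inference and returns the inferred shapes as
// tensors, while apply_on_var_node does the same symbolically via Alloc +
// GetVarShape.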
auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    size_t nr_inputs = inputs.size();
    mgb_assert(nr_inputs > 0, "no inputs for ShapeInfer");
    SmallVector<LogicalTensorDesc> input_descs;
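    // each input is a 1-D int32 tensor holding a runtime shape; rebuild the
    // logical layout from it using the dtype/device recorded on the op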
    for (size_t i = 0; i < nr_inputs; ++i) {
        auto input = inputs[i]->get_value();
        TensorLayout layout;
        layout.ndim = input.shape(0);
        for (size_t i = 0; i < layout.ndim; ++i) {
            layout[i] = input.ptr<int32_t>()[i];
        }
        layout.dtype = op.dtypes[i];
        layout.init_contiguous_stride();
        input_descs.push_back({layout, op.devices[i]});
    }
    auto [output_descs, valid] =
            OpDef::infer_output_attrs_fallible(*op.op, input_descs);
    mgb_assert(valid, "shape inference incomplete");
    SmallVector<TensorPtr> outputs;
    for (auto&& output_desc : output_descs) {
        HostTensorND shape_tensor{
                output_desc.comp_node, {output_desc.layout.ndim}, dtype::Int32()};
        for (size_t i = 0; i < output_desc.layout.ndim; ++i) {
            shape_tensor.ptr<int32_t>()[i] = output_desc.layout[i];
        }
        auto output = Tensor::make(shape_tensor);
        outputs.push_back(output);
    }
    return outputs;
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    size_t nr_inputs = inputs.size();
    VarNodeArray input_values, outputs;
    mgb_assert(nr_inputs > 0, "no inputs for ShapeInfer");
    for (size_t i = 0; i < nr_inputs; ++i) {
        auto input_value =
                opr::Alloc::make(SymbolVar(inputs[i]), op.dtypes[i], {op.devices[i]});
        input_values.push_back(input_value.node());
    }
    auto output_values = OpDef::apply_on_var_node(*op.op, input_values);
    for (auto&& output_value : output_values) {
        outputs.push_back(opr::GetVarShape::make(output_value).node());
    }
    return outputs;
}

auto infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    SmallVector<LogicalTensorDesc> input_shape_descs;
    size_t nr_inputs = op.devices.size();
    mgb_assert(
            op.dtypes.size() == nr_inputs,
            "number of input devices and dtypes mismatch");
    for (size_t i = 0; i < nr_inputs; ++i) {
        LogicalTensorDesc input_shape_desc;
        input_shape_desc.comp_node = op.devices[i];
        input_shape_desc.layout.ndim = 0;
        input_shape_desc.layout.dtype = op.dtypes[i];
        input_shape_descs.push_back(input_shape_desc);
    }
    auto [output_shape_descs, _] =
            OpDef::infer_output_attrs_fallible(*op.op, input_shape_descs);
    SmallVector<LogicalTensorDesc> output_descs;
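    // every output of ShapeInfer is a 1-D int32 shape tensor; its length (the
    // ndim of the wrapped op's output) is unknown here, so the result is
    // reported as not validated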
    for (auto&& output_shape_desc : output_shape_descs) {
        LogicalTensorDesc output_desc;
        output_desc.comp_node = output_shape_desc.comp_node;
        output_desc.layout.ndim = 1;
        output_desc.layout.dtype = dtype::Int32();
        output_descs.push_back(output_desc);
    }
    return std::make_tuple(output_descs, false);
}

auto props(const OpDef& def) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    return OpDef::props(*op.op);
}

auto make_name(const OpDef& def) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    MGB_MARK_USED_VAR(op);
    return ssprintf("ShapeInfer[%s]", op.op->make_name().c_str());
}

auto hash(const OpDef& def) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    return op.op->hash();
}

auto is_same_st(const OpDef& def, const OpDef& another) {
    if (!another.same_type<ShapeInfer>()) {
        return false;
    }
    auto& lhs = def.cast_final_safe<ShapeInfer>();
    auto& rhs = another.cast_final_safe<ShapeInfer>();
    if (!lhs.op->is_same(*rhs.op)) {
        return false;
    }
    return std::tie(lhs.devices, lhs.dtypes) == std::tie(rhs.devices, rhs.dtypes);
}

OP_TRAIT_REG(ShapeInfer, ShapeInfer)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .make_name(make_name)
        .props(props)
        .hash(hash)
        .is_same_st(is_same_st)
        .fallback();
}  // namespace shape_infer
}  // namespace

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShapeInfer);

namespace {
namespace identity {
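// Identity forwards its single input unchanged; on physical tensors it returns
// the input TensorPtr itself without copying.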
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Identity>();
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::Identity::make(inputs[0], config);
}

auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    return SmallVector<TensorPtr>{inputs[0]};
}
OP_TRAIT_REG(Identity, Identity)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace identity
}  // namespace

namespace {
namespace subgraph {

EncodedSubgraph make_forward_graph(
        const OpDef& def, SmallVector<LogicalTensorDesc> inputs) {
    return EncodedSubgraph::make(*def.cast_final_safe<SubgraphOp>().graph);
}

EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        SmallVector<bool> output_has_grad) {
    auto& op = def.cast_final_safe<SubgraphOp>();
    mgb_assert(output_has_grad.size() == op.output_grad_mask.size());
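    // outputs masked out by output_grad_mask never contribute gradients,
    // regardless of what the caller requested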
    for (size_t i = 0; i < output_has_grad.size(); ++i) {
        if (!op.output_grad_mask[i]) {
            output_has_grad[i] = false;
        }
    }
    auto bgraph = subgraph_detail::make_backward_graph(
            def, inputs, input_requires_grad, output_has_grad);
    return EncodedSubgraph::make_single(
            SubgraphOp::make(
                    op.name + "Grad", std::make_shared<Subgraph>(bgraph.graph)),
            bgraph.input_mask, bgraph.output_mask);
}

std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
    auto& op = def.cast_final_safe<SubgraphOp>();
    return {
            {"name", op.name},
            {"inputs", mgb::imperative::to_string(op.graph->inputs)},
            {"exprs", mgb::imperative::to_string(op.graph->exprs)},
            {"outputs", mgb::imperative::to_string(op.graph->outputs)},
    };
}

std::string make_name(const OpDef& def) {
    auto& op = def.cast_final_safe<SubgraphOp>();
    if (op.name.empty()) {
        return "SubgraphOp";
    } else {
        return op.name;
    }
}

auto hash(const OpDef& def) {
    auto& op = def.cast_final_safe<SubgraphOp>();
    if (!op.graph_key) {
        return (size_t) reinterpret_cast<uintptr_t>(op.graph.get());
    }
    return op.graph_key->hash();
}

auto is_same_st(const OpDef& def, const OpDef& another) {
    if (!another.same_type<SubgraphOp>()) {
        return false;
    }
    auto& lhs = def.cast_final_safe<SubgraphOp>();
    auto& rhs = another.cast_final_safe<SubgraphOp>();
    auto has_graph_key = bool(lhs.graph_key);
    bool graph_same = false;
    if (has_graph_key) {
        graph_same = rhs.graph_key && lhs.graph_key->is_same(*rhs.graph_key);
    } else {
        graph_same = !rhs.graph_key && lhs.graph.get() == rhs.graph.get();
    }
    return graph_same;
}

OP_TRAIT_REG(SubgraphOp, SubgraphOp)
        .make_forward_graph(make_forward_graph)
        .make_backward_graph(make_backward_graph)
        .props(props)
        .make_name(make_name)
        .hash(hash)
        .is_same_st(is_same_st)
        .fallback();

}  // namespace subgraph
}  // namespace

namespace {
namespace compiled_op {
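// CompiledOp wraps an inner OpDef and executes it through a dedicated
// ComputingGraph. Compiled graphs are cached per (wrapped op, input
// descriptors) and reused across calls; separate graphs are kept for shape
// inference and for execution (see HolderKind below).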

struct DeviceMemoryAllocatorImpl : cg::DeviceMemoryAllocator {
    std::shared_ptr<OpDef> current_op;
    void alloc_static(
            ComputingGraph* graph, DeviceTensorStorage& dest, size_t size) override {
        mgb_assert(0, "alloc_static is not allowed in CompiledOp");
    }
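    // route dynamic allocations through the op that is currently executing so
    // the imperative allocator (and DTR) manages the memory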
    void alloc_dynamic(VarNode* var, DeviceTensorStorage& dest, size_t size) override {
        auto comp_node = var->comp_node();
        auto storage = current_op->allocate(comp_node, size);
        dest.reset(comp_node, size, storage);
    }
};

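// A cached computing graph is specialized either for static shape/value
// inference or for actual execution; the two kinds use different input
// placeholders (see ComputingGraphHolder::initialize).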
enum class HolderKind {
    ShapeInfer,
    Execute,
};

template <HolderKind Kind>
struct ComputingGraphHolder {
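    // Per-input placeholders: device_value feeds the real device tensor during
    // execution, host_value carries small host-side values for value-dependent
    // inference, and host_shape carries only the shape for inference-only
    // graphs.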
    struct Input {
        std::shared_ptr<DeviceTensorND> device_value;
        std::shared_ptr<HostTensorND> host_value;
        std::shared_ptr<HostTensorND> host_shape;
    };
    std::shared_ptr<ComputingGraph> graph;
    std::unique_ptr<cg::AsyncExecutable> executable;
    SmallVector<Input> inputs;
    SmallVector<std::shared_ptr<DeviceTensorND>> device_outputs;
    SmallVector<VarNode*> input_vars;
    SmallVector<VarNode*> output_vars;
    std::shared_ptr<DeviceMemoryAllocatorImpl> allocator;
    SmallVector<std::shared_ptr<CompNode::Event>> events;
    std::unique_ptr<cg::static_infer::StaticInferUpdater> updater;

    void initialize(
            const CompiledOp& op, const SmallVector<LogicalTensorDesc>& input_descs) {
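        // Build a private ComputingGraph for the wrapped op: create placeholder
        // input vars (MutableTensor for execution, Host2DeviceCopy/Alloc for
        // shape inference), forward them through the inner op, register output
        // callbacks or static-infer targets, compile, and record one event per
        // comp node touched by the compiled sequence.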
        allocator = std::make_shared<DeviceMemoryAllocatorImpl>();
        graph = ComputingGraph::make();
        graph->options().force_dynamic_alloc = true;
        graph->options().async_exec_level = 0;
        graph->options().graph_opt_level = op.gopt_level;
        graph->options().enable_var_mem_defragment = false;
        graph->options().comp_seq_sync_device = false;
        // set allocator for DTR support
        graph->set_device_memory_allocator(allocator);
        if constexpr (Kind == HolderKind::ShapeInfer) {
            updater = cg::static_infer::StaticInferUpdater::make();
        }
        for (auto&& desc : input_descs) {
            Input input;
            VarNode* input_var = nullptr;
            if constexpr (Kind == HolderKind::Execute) {
                input.device_value = std::make_shared<DeviceTensorND>();
                input.device_value->dtype(desc.layout.dtype);
                input.device_value->comp_node(desc.comp_node);
                input.device_value->resize(desc.layout);
                auto callback = [value = input.device_value] { return *value; };
                if (!desc.value.empty()) {
                    input.host_value = std::make_shared<HostTensorND>();
                    input.host_value->dtype(desc.layout.dtype);
                    input.host_value->comp_node(desc.comp_node);
                }
                input_var = opr::MutableTensor::make(
                                    *graph, input.device_value, input.host_value, {})
                                    .node();
                // input_var = opr::VolatileSharedDeviceTensor::make(*graph,
                // input.device_value).node();
            } else if constexpr (Kind == HolderKind::ShapeInfer) {
                if (desc.value.empty()) {
                    input.host_shape = std::make_shared<HostTensorND>();
                    input.host_shape->dtype(dtype::Int32());
                    input.host_shape->comp_node(desc.comp_node);
                    auto input_shape_var =
                            opr::Host2DeviceCopy::make(*graph, input.host_shape);
                    input_var =
                            opr::Alloc::make(input_shape_var, desc.layout.dtype).node();
                } else {
                    input.host_value = std::make_shared<HostTensorND>();
                    input.host_value->dtype(desc.layout.dtype);
                    input.host_value->comp_node(desc.comp_node);
                    input_var =
                            opr::Host2DeviceCopy::make(*graph, input.host_value).node();
                }
            } else {
                static_assert((Kind != Kind), "unknown holder kind");
            }
            input_vars.push_back(input_var);
            inputs.push_back(input);
        }
        // forward to inner op
        output_vars = OpDef::apply_on_var_node(*op.op, input_vars);
        ComputingGraph::OutputSpec output_spec;
        CompNode::UnorderedSet comp_nodes;
        for (auto&& output_var : output_vars) {
            using namespace cg::static_infer;
            auto output_ptr = std::make_shared<DeviceTensorND>();
            auto callback = [output_ptr](DeviceTensorND output) {
                output_ptr->reset(output.storage(), output.layout());
                output = {};
            };
            if constexpr (Kind == HolderKind::ShapeInfer) {
                output_spec.push_back({output_var, callback});
                auto it = graph->static_infer_manager().get_infer_type(output_var);
                if (it.shape == InferType::RT_STATIC) {
                    updater->add_dest({output_var, DepType::SHAPE});
                }
                if (it.value == InferType::RT_STATIC) {
                    updater->add_dest({output_var, DepType::VALUE});
                }
            } else {
                auto output_callback_var =
                        opr::OutputCallback::make({callback}, output_var);
                output_spec.push_back({output_callback_var, {}});
            }
            device_outputs.push_back(output_ptr);
        }
        executable = graph->compile(output_spec);
        executable->iter_opr_seq([&](cg::OperatorNodeBase* opr) -> bool {
            for (auto&& output : opr->output()) {
                comp_nodes.insert(output->comp_node());
            }
            return true;
        });
        for (auto&& comp_node : comp_nodes) {
            events.push_back(EventPool::without_timer().alloc_shared(comp_node));
            events.back()->record();
        }
    }

    template <
            HolderKind ThisKind = Kind,
            typename = std::enable_if_t<ThisKind == HolderKind::Execute>>
    SmallVector<TensorPtr> apply_on_physical_tensor(
            const OpDef& def, const SmallVector<LogicalTensorDesc> input_descs,
            const SmallVector<TensorPtr>& input_tensors) {
        // wait for last execution
        executable->wait();
        size_t nr_inputs = inputs.size();
        for (size_t i = 0; i < nr_inputs; ++i) {
            auto input_dev_tensor = input_tensors[i]->dev_tensor();
            inputs[i].device_value->reset(
                    input_dev_tensor.storage(), input_dev_tensor.layout());
            if (inputs[i].host_value) {
                inputs[i].host_value->copy_from(input_descs[i].value);
            }
        }
        allocator->current_op = const_cast<OpDef&>(def).shared_from_this();
        executable->execute();
        for (auto&& event : events) {
            event->record();
        }
        SmallVector<TensorPtr> outputs_tensors;
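        // drop the references held by the graph's placeholders so the caller's
        // storage can be released once the outputs have been taken over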
        for (auto input : inputs) {
            *input.device_value = {};
            if (input.host_value) {
                *input.host_value = {};
            }
        }
        for (auto output_nd : device_outputs) {
            outputs_tensors.push_back(Tensor::make(*output_nd));
            *output_nd = {};
        }
        executable->clear_device_memory();
        allocator->current_op = nullptr;
        return outputs_tensors;
    }

    template <
            HolderKind ThisKind = Kind,
            typename = std::enable_if_t<ThisKind == HolderKind::ShapeInfer>>
    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
            const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
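        // feed the current input shapes/values into the host placeholders, then
        // query the graph's static-infer manager for output shapes and values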
        executable->wait();
        size_t nr_inputs = input_vars.size(), nr_outputs = output_vars.size();
        SmallVector<LogicalTensorDesc> output_descs(nr_outputs);
        for (size_t i = 0; i < nr_inputs; ++i) {
            if (inputs[i].host_shape) {
                DeviceTensorND input_shape_device_nd;
                cg::copy_shape_to_tensor_value(
                        input_shape_device_nd, input_descs[i].layout);
                inputs[i].host_shape->copy_from(input_shape_device_nd);
                mgb_assert(input_descs[i].layout.ndim, "ndim == 0");
            } else if (inputs[i].host_value) {
                inputs[i].host_value->copy_from(input_descs[i].value);
            }
        }
        updater->update();
        bool validated = true;
        for (size_t i = 0; i < nr_outputs; ++i) {
            auto infer_type =
                    graph->static_infer_manager().get_infer_type(output_vars[i]);
            const TensorShape* output_shape = nullptr;
            const DeviceTensorND* output_value = nullptr;
            auto& desc = output_descs[i];
            if (infer_type.shape != cg::static_infer::InferType::NO_DESC) {
                output_shape = graph->static_infer_manager().infer_shape_fallible(
                        output_vars[i]);
            }
            if (infer_type.value != cg::static_infer::InferType::NO_DESC) {
                output_value = graph->static_infer_manager().infer_value_fallible(
                        output_vars[i]);
            }
            if (output_shape && output_value) {
                mgb_assert(
                        output_shape->eq_shape(output_value->shape()),
                        "shape infer result mismatch, %s vs %s",
                        output_shape->to_string().c_str(),
                        output_value->shape().to_string().c_str());
            }
            if (output_shape) {
                ((TensorShape&)desc.layout) = *output_shape;
            }
            if (output_value) {
                ((TensorShape&)desc.layout) = output_value->shape();
                desc.value = *output_value;
            }
            desc.layout.dtype = output_vars[i]->dtype();
            desc.comp_node = output_vars[i]->comp_node();
            if (!desc.layout.ndim) {
                validated = false;
            }
            desc.layout.init_contiguous_stride();
        }
        return {output_descs, validated};
    }
513 514
};

static std::atomic<size_t> nr_cg_cache = 0;

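// Compiled graphs are cached keyed by (wrapped op, input descriptors); each
// cache entry holds a deque of holders so that overlapping executions of the
// same op can run on separate graphs instead of serializing on one.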
template <HolderKind Kind>
ComputingGraphHolder<Kind>& get_computing_graph(
        std::shared_ptr<OpDef> compiled_op,
        const SmallVector<LogicalTensorDesc>& descs) {
    using ComputingGraphHolderCache =
            OpMethResultCache<std::deque<std::unique_ptr<ComputingGraphHolder<Kind>>>>;
    thread_local auto& cache = ([]() -> auto& {
        mgb_assert(
                nr_cg_cache++ < 5,
                "using subgraph in too many threads, this causes resource leakage");
#if MGB_CUDA && defined(WIN32)
        // FIXME: Create as global to skip resource finalize and windows with cuda
        // doesn't cleanup global resources
        return *ResourceManager::create_global<ComputingGraphHolderCache>();
#else
        // Otherwise this should be local because compnode may be unusable when global
        // resource finalizing.
        // For example, CpuCompNode.sync hang on because underlying thread died
        return *ResourceManager::create_local<ComputingGraphHolderCache>();
#endif
    })();
    thread_local size_t nr_cg_holders = 0;
    typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
    auto& cg_holder_queue = cache[cache_key];
    std::unique_ptr<ComputingGraphHolder<Kind>> holder;
    if (!cg_holder_queue.empty()) {
        // pick one
        std::swap(cg_holder_queue.front(), holder);
        // check all events finished
        for (auto&& event : holder->events) {
            if (!event->finished()) {
                bool queue_limited =
                        event->comp_node().contain_flag(CompNode::Flag::QUEUE_LIMITED);
                bool many_graph = cg_holder_queue.size() > 10;
                if (queue_limited || !many_graph) {
                    std::swap(cg_holder_queue.front(), holder);
                    break;
                } else {
                    // graph limit
                    mgb_log_debug(
                            "computing graph limit for compiled op exceeded, waiting "
                            "for prev graph");
                    event->host_wait();
                }
            } else {
                event->host_wait();
            }
        }
        if (holder) {
            cg_holder_queue.pop_front();
        }
    }
    if (!holder) {
        // create new computing graph
        auto create_holder = [&] {
            auto holder = std::make_unique<ComputingGraphHolder<Kind>>();
            auto& cg_holder = *holder;
            cg_holder.initialize(compiled_op->cast_final_safe<CompiledOp>(), descs);
            nr_cg_holders++;
            mgb_log_debug(
                    "add new computing graph for compiled op, now %zu graphs",
                    nr_cg_holders);
            return holder;
        };
        size_t nr_graphs = std::max(cg_holder_queue.size(), (size_t)1);
        for (size_t i = 1; i < nr_graphs; ++i) {
            cg_holder_queue.push_front(create_holder());
        }
        holder = create_holder();
    }
    cg_holder_queue.push_back(std::move(holder));
    return *cg_holder_queue.back();
}

auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& input : inputs) {
        input_descs.push_back({input->layout(), input->comp_node()});
        if (auto* host_value = input->try_get_value()) {
            if (host_value->layout().total_nr_elems() <=
                MEGDNN_MAX_NDIM) {  // infer small tensor
                input_descs.back().value = host_value->proxy_to_default_cpu();
            }
        }
    }
    auto shared_def = const_cast<OpDef&>(def).shared_from_this();
    auto& cg_holder = get_computing_graph<HolderKind::Execute>(shared_def, input_descs);
    return cg_holder.apply_on_physical_tensor(def, input_descs, inputs);
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto& op = def.cast_final_safe<CompiledOp>();
    op.op->set_scope(op.scope());
    return OpDef::apply_on_var_node(*op.op, inputs);
}

auto infer_output_attrs_fallible(
        const OpDef& def, SmallVector<LogicalTensorDesc> input_descs) {
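    // if any input shape is still unknown, fall back to the inner op's own
    // shape inference instead of building a compiled graph for it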
    bool shape_all_valid = true;
    for (auto&& input_desc : input_descs) {
        if (!input_desc.layout.ndim) {
            shape_all_valid = false;
            break;
        }
    }
    if (!shape_all_valid) {
        return OpDef::infer_output_attrs_fallible(
                *def.cast_final_safe<CompiledOp>().op, input_descs);
    }
    auto shared_def = const_cast<OpDef&>(def).shared_from_this();
    for (auto& input_desc : input_descs) {
        if (input_desc.layout.total_nr_elems() >
            MEGDNN_MAX_NDIM) {  // skip large tensor
            input_desc.value = {};
        }
    }
    auto& cg_holder =
            get_computing_graph<HolderKind::ShapeInfer>(shared_def, input_descs);
    return cg_holder.infer_output_attrs_fallible(def, input_descs);
}

auto props(const OpDef& def) {
    return OpDef::props(*def.cast_final_safe<CompiledOp>().op);
}

auto make_name(const OpDef& def) {
    auto& op = def.cast_final_safe<CompiledOp>();
    MGB_MARK_USED_VAR(op);
    return ssprintf("CompiledOp[%s]", op.op->make_name().c_str());
}

EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    auto& op = def.cast_final_safe<CompiledOp>();
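    // build the inner op's backward graph, wrap it in a SubgraphOp and compile
    // it with the same graph-optimization level as the forward op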
    auto backward_graph = OpDef::make_backward_graph(
            *op.op, inputs, input_requires_grad, output_has_grad);
    auto name = def.trait()->make_name(def);
    std::shared_ptr<OpDef> bgraph_op =
            SubgraphOp::wrap(name + "Grad", backward_graph.graph);
    auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
    auto encoded_graph = EncodedSubgraph::make_single(
            compiled_op, backward_graph.input_mask, backward_graph.output_mask);
    return encoded_graph;
}

auto hash(const OpDef& def) {
    auto& op = def.cast_final_safe<CompiledOp>();
    return mgb::hash_pair_combine(op.op->hash(), op.gopt_level);
}

auto is_same_st(const OpDef& def, const OpDef& another) {
    if (!another.same_type<CompiledOp>()) {
        return false;
    }
    auto& lhs = def.cast_final_safe<CompiledOp>();
    auto& rhs = another.cast_final_safe<CompiledOp>();
    return lhs.op->is_same(*rhs.op) && lhs.gopt_level == rhs.gopt_level;
}

OP_TRAIT_REG(CompiledOp, CompiledOp)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .make_backward_graph(make_backward_graph)
        .make_name(make_name)
        .props(props)
        .hash(hash)
        .is_same_st(is_same_st)
        .fallback();
}  // namespace compiled_op
}  // namespace

namespace {
namespace jit_fusion {

static thread_local bool tm_enabled = true;

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto& op = def.cast_final_safe<JITFusionOp>();
    op.op->set_scope(op.scope());
    auto outputs = OpDef::apply_on_var_node(*op.op, inputs);
    if (!tm_enabled) {
        // skip for dump (JITExecutor can not be dumped)
        return outputs;
    }
#if MGB_JIT
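    // For each output, collect every operator between the (visited) inputs and
    // that output in topological order, feed them to an InternalGraphGenerator,
    // and replace the output with a fused JITExecutor node.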
    for (auto& output : outputs) {
        jit::InternalGraphGenerator igg{output->owner_opr()};
        std::vector<cg::OperatorNodeBase*> reverse_order;
        cg::DepOprIter iter{
                [&](cg::OperatorNodeBase* opr) { reverse_order.push_back(opr); }};
        for (auto&& input : inputs) {
            iter.set_visited(input->owner_opr());
        }
        iter.add(output->owner_opr());
        std::reverse(reverse_order.begin(), reverse_order.end());
        for (auto&& opr : reverse_order) {
            igg.add_opr(opr);
        }
        auto ig = igg.generate();
        output = jit::JITExecutor::make(ig, igg.orig_inps()).node();
    }
#else
    mgb_assert(false, "MGB_WITH_JIT was disabled");
#endif
    return outputs;
}

auto infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    return OpDef::infer_output_attrs_fallible(
            *def.cast_final_safe<JITFusionOp>().op, input_descs);
}

auto props(const OpDef& def) {
    return OpDef::props(*def.cast_final_safe<JITFusionOp>().op);
}

auto hash(const OpDef& def) {
    return def.cast_final_safe<JITFusionOp>().op->hash();
}

auto is_same_st(const OpDef& def, const OpDef& another) {
    if (!another.same_type<JITFusionOp>()) {
        return false;
    }
    auto& lhs = def.cast_final_safe<JITFusionOp>();
    auto& rhs = another.cast_final_safe<JITFusionOp>();
    return lhs.op->is_same(*rhs.op);
}

EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    return {};
}

OP_TRAIT_REG(JITFusionOp, JITFusionOp)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .props(props)
        .hash(hash)
        .is_same_st(is_same_st)
        .make_backward_graph(make_backward_graph)
        .fallback();

}  // namespace jit_fusion
}  // namespace

bool JITFusionOp::set_enabled(bool enabled) {
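    // thread-local toggle for JIT fusion; returns the previous value so the
    // caller can restore it later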
    std::swap(enabled, jit_fusion::tm_enabled);
    return enabled;
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompiledOp);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(JITFusionOp);

}  // namespace mgb::imperative