interpreter_impl.cpp 54.8 KB
Newer Older
M
Megvii Engine Team 已提交
1
/**
2
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
M
Megvii Engine Team 已提交
3 4
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
M
Megvii Engine Team 已提交
6 7 8 9 10 11
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12
#include "./interpreter_impl.h"
13

14 15
#include "range/v3/all.hpp"

16
#include "megbrain/common.h"
17 18
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
19 20
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
21
#include "megbrain/imperative/ops/utility.h"
22
#include "megbrain/imperative/utils/stats.h"
23 24
#include "megbrain/imperative/utils/to_string.h"

25
#include "../blob_manager_impl.h"
26 27 28
#include "../event_pool.h"
#include "../op_trait.h"

29 30 31 32 33
using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

34
namespace {
M
Megvii Engine Team 已提交
35 36 37 38 39 40 41 42
auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
    SmallVector<uint64_t> tid;
    for (auto* ptinfo : tinfo) {
        tid.push_back(ptinfo->id);
    }
    return tid;
};
}  // namespace
43

44
namespace mgb {
M
Megvii Engine Team 已提交
45
using namespace profiler;
46 47
}

48 49 50 51 52
#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif
53 54 55 56 57 58 59

namespace mgb {

/**
 * USAGE
 *
 *   header:
60
 *     namespace mgb { void imperative_log_profile(const char* message); }
61 62 63 64 65
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
66
SYMBOL_EXPORT
//! Open a custom profiler span labelled with `message` (host-side only).
void imperative_log_profile_begin(const char* message) {
    MGB_RECORD_EVENT(CustomEvent, std::string{message});
}

71
SYMBOL_EXPORT
//! Close the custom profiler span previously opened with the same `message`.
void imperative_log_profile_end(const char* message) {
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
}

76
SYMBOL_EXPORT
//! Record a zero-length profiler span: begin immediately followed by end.
void imperative_log_profile(const char* message) {
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

82 83 84 85
SYMBOL_EXPORT
//! Open a custom profiler span and also record a device event on `device`,
//! so the span can be correlated with device-side timelines.
void imperative_log_profile_begin(const char* message, const char* device) {
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
}

SYMBOL_EXPORT
//! Close a device-correlated profiler span: records the device event first,
//! then the host-side finish event (mirror order of the begin overload).
void imperative_log_profile_end(const char* message, const char* device) {
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
}

M
Megvii Engine Team 已提交
98
}  // namespace mgb
99

100 101 102 103 104 105 106 107 108 109 110 111 112 113
//! Thread id of the background worker thread (set on worker startup).
std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

//! Access channel-thread state; asserts the caller is on the channel thread.
ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

//! Access worker-thread state; asserts the caller is on the worker thread.
ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

114 115 116
//! Worker-thread startup hook: names the thread, publishes its id, and
//! installs an allocator that goes through the channel's eviction logic.
void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
    // NOTE(review): the lambda captures `this`/`m_owner` by reference and is
    // handed to OpDef::set_allocator, so it must not outlive the channel —
    // presumably the worker queue and channel share lifetime; verify.
    auto custom_allocator = [&](CompNode device, size_t size) {
        auto blob = Blob::make(device, size);
        // may evict other tensors (DTR) to satisfy this allocation
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
    };
    OpDef::set_allocator(custom_allocator);
}

125
// Do not use m_xxx_state directly
126 127 128
#define m_channel_state
#define m_worker_state

129 130 131 132 133 134 135 136 137
//! Factory: each channel owns its own worker queue and tensor pool.
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

//! Process-wide interpreter singleton (thread-safe function-local static).
Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

138
//! Upload a host tensor into the channel and return an opaque handle.
//! The profiler guard is only created when profiling is on, to keep the
//! fast path free of stack-manager bookkeeping.
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        auto& state = get_channel_state();
        guard.emplace("Put", &state.stack_manager);
    }
    auto info = put_impl(value, no_cache);
    return reinterpret_cast<Handle>(info);
}

//! Core of `put`: allocate a TensorInfo, optionally cache a small host copy,
//! and enqueue the actual upload on the worker thread.
TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    if (value.empty()) {
        // normalize strides of an empty tensor; NOTE(review): this const_casts
        // and mutates the caller's tensor in place — intentional upstream
        // behavior, but callers should not rely on the original strides.
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto info = alloc();
    // small tensors (element count <= MAX_NDIM) also keep a host-side value so
    // shapes/scalars can be read without a device round-trip
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {value.layout(), value.comp_node()});
    if (value.layout().total_nr_elems() <= size_threshold) {
        info->h_value = value;
        info->desc.value = value.proxy_to_default_cpu();
    }
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Put{info, value, no_cache},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Put{info, value, no_cache},
        });
    }
    // fully synchronous mode: wait for the upload and surface device errors now
    if (m_async_level == 0) {
        sync_impl();
        info->desc.comp_node.sync();
        auto err = info->desc.comp_node.check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info;
}

182
//! Wrap an existing device tensor (with optional host mirror) into a handle.
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return reinterpret_cast<Handle>(put_impl(data, hvalue));
}
M
Megvii Engine Team 已提交
187 188
//! Core of device-tensor `put`: unlike the host path this needs no worker
//! task — the data already lives on device, so the TensorInfo is marked
//! Produced immediately.
TensorInfo* ChannelImpl::put_impl(
        const DeviceTensorND& data, const HostTensorND& hvalue) {
    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        auto& state = get_channel_state();
        guard.emplace("Put", &state.stack_manager);
    }
    auto info = alloc();
    MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
    // same small-tensor value-caching threshold as the host-tensor put_impl
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {data.layout(), data.comp_node()});
    if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
        info->desc.value = hvalue.proxy_to_default_cpu();
    }
    info->ptr = Tensor::make(data, hvalue);
    MGB_RECORD_EVENT(
            TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
            data.raw_ptr());
    info->status = TensorInfo::Produced;
    MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
    return info;
}

210
//! Delete a handle. Silently ignored after close() — close() already deleted
//! every live handle, so late dels from destructors are harmless.
void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    del_impl(handle);
}

//! Invalidate the handle on the channel thread and enqueue the actual tensor
//! teardown (Del command) on the worker thread.
void ChannelImpl::del_impl(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Del{info},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Del{info},
        });
    }
}

234
//! Request DTR drop of a tensor's device storage (recomputable later).
//! No-op unless the enable_drop option is set.
void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(
                m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), Drop{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    Drop{info},
            });
        }
    }
}

256
//! Execute an op eagerly on the host ("default CPU" dispatch): all inputs must
//! have host values available, outputs are produced synchronously and re-put
//! into the channel. Used for cheap shape-ish computations to avoid device
//! round-trips.
void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();

    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
    }

    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_RECORD_EVENT(ShapeInferEvent, validated);

    // Gather host copies of all inputs while holding m_mutex, so the worker
    // thread cannot drop a tensor's ptr out from under us. All inputs must
    // agree on one comp node, which becomes the outputs' comp node.
    SmallVector<DeviceTensorND> input_tensornds;
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(
                        info->ptr->get_value().proxy_to_default_cpu());
            } else {
                // We assign h_value before drop ptr
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    SmallVector<DeviceTensorND> output_tensornds;
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which need alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(
                HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    uint64_t op_id = Profiler::next_id();

    if (op->trait()->apply_on_device_tensornd) {
        OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
    } else {
        // proxy to apply_on_physical_tensor
        SmallVector<TensorPtr> input_tensors;
        for (auto&& input_tensornd : input_tensornds) {
            input_tensors.push_back(Tensor::make(
                    input_tensornd, HostTensorND::make_proxy(input_tensornd)));
        }
        auto output_tensors = OpDef::apply_on_physical_tensor(
                *op, input_tensors, output_descs, validated);
        for (size_t i = 0; i < output_tensors.size(); ++i) {
            output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
        }
    }

    // Register each result through put_impl so the rest of the machinery sees
    // them like any other tensor.
    SmallVector<TensorInfo*> output_infos;
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd =
                HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(reinterpret_cast<Handle>(info));
    }
    auto op_info_getter = [op] {
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value] : props) {
            op_info[key] = value;
        }
        return op_info;
    };
    // guard.value() is only safe because MGB_RECORD_EVENT arguments are
    // presumably evaluated only when profiling is active (guard was emplaced
    // above under the same condition) — verify if the macro changes.
    MGB_RECORD_EVENT(
            OpDispatchEvent, op_id, guard.value().name(), op_info_getter,
            tinfo_to_tid(input_infos), tinfo_to_tid(output_infos),
            state.stack_manager.dump());
}
344

345
void ChannelImpl::dispatch_kernel(
M
Megvii Engine Team 已提交
346
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
347 348
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
349
    auto& state = get_channel_state();
350 351
    auto& options = state.options;

352 353 354 355
    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
    }
356

M
Megvii Engine Team 已提交
357 358
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
359
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
360

361 362 363 364
    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());

    outputs->reserve(output_descs.size());
365 366
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
367
        auto info = alloc();
368
        init(info, std::move(desc));
369 370
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
371 372
            info->h_value = HostTensorND::make_proxy(info->desc.value)
                                    .proxy_to_comp_node(info->desc.comp_node);
373
        }
374
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
375
        outputs->push_back(reinterpret_cast<Handle>(info));
376
    }
377 378 379
    ApplyOp cmd{
            Profiler::next_id(), std::move(op), std::move(input_infos),
            std::move(output_infos), validated};
380
    if (Profiler::is_profiling()) {
381 382 383 384 385 386 387 388
        auto op_info_getter = [op = cmd.op] {
            std::unordered_map<std::string, std::string> op_info;
            auto props = OpDef::props(*op);
            for (auto&& [key, value] : props) {
                op_info[key] = value;
            }
            return op_info;
        };
389
        MGB_RECORD_EVENT(
390 391 392
                OpDispatchEvent, cmd.id, guard.value().name(), op_info_getter,
                tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs),
                state.stack_manager.dump());
393
        m_worker.add_task(
394
                {Profiler::next_id(), std::move(cmd),
395 396 397 398
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
399
                std::move(cmd),
400 401
        });
    }
402
    if (!validated && options.async_level == 1) {
403
        sync_impl();
404
    } else if (options.async_level == 0) {
405
        sync_impl();
406
        // check device error
407
        for (auto&& oup : *outputs) {
408 409
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
410 411
            auto err = info->ptr->comp_node().check_async_error();
            mgb_assert(!err, "%s", err->what());
412
        }
413
    }
414 415 416
}

//! Public op-apply entry. Fast path: GetVarShape over all axes with a known
//! input shape is answered directly from the cached layout, skipping the
//! worker queue entirely.
//! Fix vs original: `inputs[0]` was read before checking that `inputs` is
//! non-empty, an out-of-bounds access for nullary ops; the fast path now
//! guards on `!inputs.empty()` first.
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    if (!inputs.empty() && op->same_type<GetVarShape>()) {
        auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
        if (input->desc.layout.ndim) {
            size_t ndim = input->desc.layout.ndim;
            auto& gvs = op->cast_final_safe<GetVarShape>();
            // MEGDNN_MAX_NDIM means "all axes": materialize the shape as a
            // small Int32 host tensor and put it like a regular tensor
            if (gvs.axis == MEGDNN_MAX_NDIM) {
                HostTensorND shape_tensor{
                        input->desc.comp_node, {ndim}, dtype::Int32()};
                DeviceTensorND shape_tensor_device =
                        shape_tensor.proxy_to_default_cpu();
                cg::copy_shape_to_tensor_value(
                        shape_tensor_device, input->desc.layout);
                return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
            }
        }
    }
    return apply_op_impl(std::move(op), inputs);
}

//! Validate inputs, snapshot their descriptors under m_info_spin, then route
//! to either the eager host path (DEFAULT_CPU) or the async worker path
//! (KERNEL) depending on the op's dispatch decision.
SmallVector<Handle> ChannelImpl::apply_op_impl(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(
                m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
                i);
    }
    SmallVector<TensorInfo*> input_infos;
    SmallVector<LogicalTensorDesc> input_descs;
    {
        // copy descs under the spin lock so the worker can't mutate them mid-read
        MGB_LOCK_GUARD(m_info_spin);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(
                    !info->invalid,
                    "an input tensor is unusable due to previous error");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    // host compute is an optimization; when disabled, always go to the kernel
    DispatchMode dispatch_mode = state.options.enable_host_compute
                                       ? OpDef::decide_dispatch_mode(*op, input_descs)
                                       : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

473
//! Blocking read of a tensor's host value; waits on the worker if needed.
HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

485
//! Shape query; answered from the cached descriptor when known, otherwise
//! blocks until the worker has produced the tensor.
TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        // fast path: shape already statically known
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

500
//! Dtype query; always available from the cached descriptor, never blocks.
DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto* tinfo = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, tinfo->id, TensorProp::DType);
    DType dtype = tinfo->desc.layout.dtype;
    mgb_assert(dtype.valid());
    return dtype;
}

513
//! Comp-node query; always available from the cached descriptor, never blocks.
CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto* tinfo = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, tinfo->id, TensorProp::Device);
    CompNode device = tinfo->desc.comp_node;
    mgb_assert(device.valid());
    return device;
}

526
//! Blocking access to the underlying device tensor.
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

//! Public barrier: drain all queued worker tasks and re-raise worker errors.
void ChannelImpl::sync() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    sync_impl();
}

//! Wait for the worker queue to empty, then surface any exception the worker
//! recorded (check_worker_exc_unsafe needs m_mutex held).
void ChannelImpl::sync_impl() {
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

//! Shut the channel down: delete every outstanding handle, drain the worker,
//! then mark closed. Idempotent — a second call returns immediately.
void ChannelImpl::close() {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    // snapshot first: del_impl mutates m_valid_handle while we iterate
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
    for (auto* handle : valid_handles) {
        del_impl(handle);
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
    sync_impl();
    m_closed = true;
}

563
//! Read a named channel option (channel-thread state).
size_t ChannelImpl::get_option(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    return state.options.get_option(name);
}

570
//! Set a named option on the channel thread and propagate it to the worker
//! via a SetOption task so both sides stay consistent.
void ChannelImpl::set_option(std::string name, size_t value) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.options.set_option(name, value);
    // FIXME
    // Enabling DTR also swaps the global blob allocator for the evicting one.
    if (name == "enable_dtr_auto_drop" && value) {
        auto custom_allocator = [&](CompNode device, size_t size) {
            auto blob = Blob::make(device, size);
            alloc_tensor_with_evict(blob.get());
            return blob->storage();
        };
        BlobManager::inst()->set_allocator(custom_allocator);
    }
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), SetOption{name, value},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                SetOption{name, value},
        });
    }
}

596 597 598 599 600 601
//! Reset the DTR eviction-candidate set.
void ChannelImpl::clear_candidates() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_dtr.candidates.clear();
}

602
//! Allocate a fresh TensorInfo from the pool, assign its profiler id, and —
//! when profiling — a human-readable name derived from the call stack.
TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this] {
        // keep the spin-lock scope to just the pool allocation
        MGB_LOCK_GUARD(m_pool_spin);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        size_t tensor_id = state.stack_manager.current()->next_id("tensor");
        info->name =
                state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
    }
    return info;
}

617
//! Register a freshly allocated TensorInfo as a valid handle and record its
//! declared (not yet produced) descriptor.
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
    m_valid_handle.insert(reinterpret_cast<Handle>(info));
    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    info->desc = std::move(desc);
}

M
Megvii Engine Team 已提交
624
//! Drop a tensor's device storage for DTR; only possible when a producer op
//! is recorded so the value can be regenerated later. `user` distinguishes an
//! explicit user drop (warn on failure) from internal eviction (silent).
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn(
                    "the input that produced tensor %p has been deleted, this drop "
                    "operation will be ignored",
                    ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        // already dropped or swapped; nothing to do
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

642
//! Free a tensor on the worker. With DTR enabled, a tensor still referenced
//! by recompute chains (ref_cnt > 0) is only dropped, not freed, so its
//! dependents can regenerate; otherwise it is recursively freed.
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

//! Free `ptr` and cascade to any producer inputs whose reference count drops
//! to zero and that were already marked deletable.
void ChannelImpl::recursive_free(TensorInfo* ptr) {
    MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    // free ptr first, then recurse — collect candidates before real_free
    // detaches the producer record
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
}

//! Final teardown of a TensorInfo: remove from DTR candidates, unlink from
//! its users/producer, emit profiler events, and return it to the pool.
void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    MGB_LOCK_GUARD(m_pool_spin);
    m_pool.free(ptr);
}

695
//! The worker queue is bound back to this channel at construction.
ChannelImpl::ChannelImpl() : m_worker(this) {}

//! Destruction closes the channel (idempotent; see close()).
ChannelImpl::~ChannelImpl() {
    close();
}
700

701
//! Worker-side completion: attach a physical tensor to `dest`, validate the
//! statically inferred shape, update DTR bookkeeping, and wake any channel
//! thread blocked in wait_tensor.
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
            ptr->dev_tensor(false).raw_ptr());
    // update tensor desc for static infer
    if (dest->desc.layout.ndim) {
        mgb_assert(
                dest->desc.layout.eq_shape(ptr->layout()),
                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
                ptr->layout().to_string().c_str());
    }
    // in order to avoid performance impact,
    // memory forwarding is disabled when DTR is enabled
    if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
        ptr->to_contiguous_inplace();
    }
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    // pinned tensors and tiny tensors are never eviction candidates
    if (dest->pinned == 0 &&
        dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

732
//! Release a tensor's physical storage (but keep its TensorInfo alive) and
//! remove it from the DTR candidate set.
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
}

742
//! Recompute a DTR-dropped tensor by pushing its recorded producer op onto
//! the apply stack; the stack is flushed unless we are already inside a
//! flush (re-entrancy guard via m_applying).
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        auto&& path = dest->producer;
        m_apply_stack.push(
                {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
                 "dtr"});
        if (!m_applying)
            flush_apply_stack();
    }
}

753
// Execute one ApplyOp command on the worker thread: gather input tensors,
// optionally trigger DTR auto-eviction, run the op (recursively expanding
// ops that provide a forward graph), record profiling events around the
// kernel launches, publish the produced outputs, and update DTR bookkeeping.
// `reason` is a label ("cmd" or "dtr") recorded with the profiling events.
void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    bool profiling_device =
            Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
    uint64_t apply_id = cmd.id;
    SmallVector<TensorPtr> inputs;
    inputs.reserve(cmd.inputs.size());
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
        // refcnt ++, owners: [i->ptr, tensor_inputs]
        // tensor_inputs.push_back(i->ptr);
        inputs.push_back(i->ptr);
    }
    // Keep memory under the DTR threshold before launching a new op.
    if (state.options.enable_dtr_auto_drop &&
        state.options.dtr_eviction_threshold > 0) {
        auto_evict(0);
    }
    // Self-recursive lambda (`self` is the lambda itself): ops exposing a
    // forward graph are expanded and each sub-op applied recursively;
    // leaf ops run through OpDef::apply_on_physical_tensor.
    auto apply_on_physical_tensor =
            [&](auto&& self, const OpDef& def, SmallVector<TensorPtr>&& inputs,
                SmallVector<LogicalTensorDesc>& output_descs,
                const bool& validated) -> SmallVector<TensorPtr> {
        if (def.trait()->make_forward_graph) {
            auto apply_functor = [&](std::shared_ptr<OpDef> op,
                                     SmallVector<TensorPtr> inputs,
                                     size_t nr_outputs) -> SmallVector<TensorPtr> {
                auto opname = op->trait()->make_name(*op);
                imperative_log_profile_begin(opname.c_str());
                // Sub-ops are never pre-validated, hence `false`.
                auto outputs = self(self, *op, std::move(inputs), output_descs, false);
                imperative_log_profile_end(opname.c_str());
                return outputs;
            };
            auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
            // apply recursivily
            SmallVector<LogicalTensorDesc> input_descs;
            for (auto&& input : inputs) {
                input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
            }
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
            auto outputs = forward_graph.apply<TensorPtr>(
                    inputs, apply_functor, const_functor);
            return outputs;
        }
        // Check Input Layout
        // Get the input layout constraints, and if the constraint is not satisfied
        // inplace update the layout and blob to make the tensor contiguous
        auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
        for (size_t idx = 0; idx < inputs.size(); ++idx) {
            auto&& layout_checker = constraints[idx];
            if (layout_checker) {
                inputs[idx]->to_contiguous_inplace(layout_checker);
            }
        }
        return OpDef::apply_on_physical_tensor(
                def, std::move(inputs), output_descs, validated);
    };
    MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        // One kernel-profiling id per distinct comp_node touched by the op.
        SmallVector<CompNode> devices;
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
            }
        }
    }
    for (auto* input : cmd.inputs) {
        auto input_id = input->id;
        MGB_RECORD_EVENT(OpInputEvent, input_id);
        MGB_RECORD_EVENT(TensorUsageEvent, input_id);
        MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
    }
    // Before wait
    // TODO: split operator wait and execute so that OpWait could be corrected recorded.
    // Before execute
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
        MGB_RECORD_EVENT_IF(
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
    }
    // Apply op
    SmallVector<LogicalTensorDesc> output_descs;
    bool validated = cmd.validated;
    // With DTR auto-drop the cached output descs may be stale (tensors get
    // dropped/regenerated), so force re-inference by leaving descs empty.
    if (!state.options.enable_dtr_auto_drop) {
        for (auto i : cmd.outputs) {
            output_descs.push_back(i->desc);
        }
    } else {
        validated = false;
    }
    // Here std::move is REQUIRED for removing duplicated references.
    auto outputs = apply_on_physical_tensor(
            apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs,
            validated);
    // After execute
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT_IF(
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
        MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
    }
    // End profiling operator
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto output = cmd.outputs[i];
        if (mgb_unlikely(output == nullptr)) {
            // Output slot was dropped by the caller; nothing to publish.
            MGB_RECORD_EVENT(OpOutputEvent, 0);
            MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
        } else if (mgb_unlikely(output->ptr != nullptr)) {
            // Output already materialized (e.g. by a concurrent regen path).
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
        } else {
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            produce_tensor(output, outputs[i]);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
            sample_on_device(output->desc.comp_node, false);
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        // Heuristic compute-time estimate: proportional to total bytes of
        // inputs plus outputs (see the 1e8 scaling into estimate_timestamp).
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
        for (auto i : outputs) {
            estimate_compute_time += i->blob()->size();
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
        // Inputs were pinned by flush_apply_stack(); release them now.
        m_dtr.unpin(cmd.inputs, state);
    }
    MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
    // End profiling operator
}
896

897 898
// Drain m_apply_stack, executing each pending ApplyOp. Ops whose inputs
// have been dropped push a regeneration ApplyOp on top of the stack (via
// regenerate()), so this loop effectively performs an iterative DFS over
// the recompute dependency graph. m_applying guards against re-entrant
// flushes from regenerate().
void ChannelImpl::flush_apply_stack() {
    m_applying = true;
    auto& state = get_worker_state();
    while (!m_apply_stack.empty()) {
        auto& [cmd, idx, recomp, reason] =
                m_apply_stack.top();  // cmd.inputs[0~idx-1] is in memory
        // First visit of this entry: pin inputs so eviction cannot reclaim
        // them mid-recompute, and open the ReGen profiling span.
        if (idx == 0) {
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
                MGB_RECORD_EVENT(
                        TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
            }
        }
        bool regen = false;
        for (size_t i = idx; i < cmd.inputs.size(); i++) {
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                // Resume from i+1 after the dependency is rebuilt.
                idx = i + 1;
                regenerate(p);  // add ApplyOp to the stack
                regen = true;
                break;
            }
        }
        if (regen)
            continue;
        // the required input tensors are already in memory
        // Copy before pop: do_apply_op may push new entries, invalidating
        // the structured-binding references into the stack top.
        auto [cmd_backup, recomp_backup, reason_backup] =
                std::make_tuple(cmd, recomp, reason);
        m_apply_stack.pop();
        do_apply_op(cmd_backup, reason_backup);
        if (recomp_backup) {
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, recomp_backup->id,
                    TensorCommandKind::ReGen);
            for (auto o : cmd_backup.outputs) {
                if (o) {
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
    }
    m_applying = false;
}

946
// Evict tensors chosen by the DTR policy until used memory drops below the
// configured eviction threshold and `force_num` forced evictions have been
// performed. Returns true if at least one tensor's memory was actually
// reclaimed (i.e. this call held the last references to its blob).
bool ChannelImpl::auto_evict(size_t force_num) {
    auto& state = get_worker_state();
    if (!m_dtr.comp_node.valid()) {
        // DTR has no target device yet; nothing can be evicted.
        return false;
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
    // Fix: was declared `size_t flag = false;` while being used and
    // returned as a boolean — use the proper type.
    bool evicted_any = false;
    while ((state.options.dtr_eviction_threshold > 0 &&
            current_memory > state.options.dtr_eviction_threshold) ||
           force_num > 0) {
        MGB_RECORD_EVENT(AutoEvictEvent);
        sample_on_device(m_dtr.comp_node, false);
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
        if (!best) {
            // No eviction candidate left; give up.
            MGB_RECORD_EVENT(AutoEvictFinishEvent);
            break;
        }
        // Memory is only truly reclaimed when nothing else shares the
        // tensor or its blob; only then does this eviction count.
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
            if (force_num > 0) {
                force_num--;
            }
            evicted_any = true;
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
        }
        sample_on_device(m_dtr.comp_node, false);
        MGB_RECORD_EVENT(AutoEvictFinishEvent);
    }
    return evicted_any;
}

980 981
// Detach every ComputePath that consumes `dest`, regenerating the paths'
// outputs first so they remain usable after their producer record is gone.
void ChannelImpl::detach_users(TensorInfo* dest) {
    // Iterate over a snapshot: detach_producer() below mutates dest->users.
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user : users) {
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output : outputs) {
            // When a `ComputePath` is detach from it's input,
            // there is no need to reserve it,
            // so we detach all output of this path
            // to decrease it's `ref_cnt` to zero.
            if (output == nullptr) {
                continue;
            }
            // Materialize before detaching, since the producer path is the
            // only way to rebuild a dropped output.
            regenerate(output);
            output->detach_producer();
            // NOTE(review): inputs lose one ref per detached output here —
            // presumably ref_cnt counts (input, output) edges; verify against
            // ComputePath::make.
            for (auto* input : inputs) {
                input->ref_cnt--;
            }
        }
        // now user is dead
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}

1004 1005 1006 1007
bool ChannelImpl::check_available() {
    return !m_closed;
}

1008 1009 1010 1011 1012
// Block the calling (channel) thread until `info` satisfies `prop`:
// HostValue requires the value to be fetched to host; any other prop only
// requires the device tensor to exist. Registers `info` as the single
// waitee so the worker can wake this thread via notify_tensor_unsafe().
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool require_host = prop == TensorProp::HostValue;
    auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
    bool wait_host = false;
    if (require_host && !host_available()) {
        // avoid dead lock
        // Enqueue a GetValue task while unlocked: the worker needs m_mutex
        // to complete it.
        lock.unlock();
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), GetValue{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    GetValue{info},
            });
        }
        lock.lock();
        wait_host = true;
    }
    // The predicate also rethrows any worker exception as AsyncError.
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return require_host ? host_available() : static_cast<bool>(info->ptr);
    });
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
    m_waitee = nullptr;
    if (wait_host) {
        // Surface asynchronous device errors from the fetch we triggered.
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info->ptr;
}

// Wake the channel thread blocked in wait_tensor(), but only when `info`
// is the tensor currently being waited on. Caller must hold m_mutex
// ("unsafe" = no internal locking).
void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info != m_waitee) {
        return;
    }
    MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
    m_cv.notify_all();
}

// Snapshot every currently-valid handle as its underlying TensorInfo*.
std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> result;
    result.reserve(m_valid_handle.size());
    for (auto* raw_handle : m_valid_handle) {
        result.insert(reinterpret_cast<TensorInfo*>(raw_handle));
    }
    return result;
}

1062
// Allocate storage for blob `x`, using DTR eviction to make room on
// failure and falling back to a full defragmentation as a last resort.
// Logging is silenced for the duration so repeated failed allocation
// attempts do not spam warnings; the previous level is always restored.
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
    // Eviction may only be driven from the worker thread.
    bool in_worker = (get_worker_tid() == std::this_thread::get_id());
    // Evict one tensor at a time until a free block large enough exists.
    auto reserve_size = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            bool evict_suc = auto_evict(1);
            if (!evict_suc)
                return false;
        }
        return true;
    };
    auto pre_level = set_log_level(LogLevel::NO_LOG);
    if (in_worker) {
        reserve_size(x->size());
    }
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
    MGB_CATCH(MemAllocError&, {
        bool suc = false;
        if (in_worker) {
            // Retry: evict one candidate, attempt allocation again, until
            // either the allocation succeeds or nothing is left to evict.
            while (!suc) {
                if (!auto_evict(1)) {
                    break;
                }
                MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
                MGB_CATCH(MemAllocError&, { continue; });
                suc = true;
            }
        }
        if (!suc) {
            // Temporarily restore logging so the defrag warning is visible.
            set_log_level(pre_level);
            mgb_log_warn(
                    "reallocating all cuda memory to alleviate fragmentation, the "
                    "performance may be affected");
            set_log_level(LogLevel::NO_LOG);
            imperative_log_profile_begin("defrag");
            BlobManager::inst()->defrag(x->comp_node());
            imperative_log_profile_end("defrag");
            // Final attempt; any failure here propagates to the caller.
            BlobManager::inst()->alloc_direct(x, x->size());
        }
    });
    set_log_level(pre_level);
}

1107
// Worker-thread dispatcher: executes a single queued Command. The inner
// `cmd_visitor` handles each command type via `if constexpr`; the outer
// std::visit wrapper optionally catches worker exceptions, marks affected
// tensors invalid, stashes the exception for the channel thread, and wakes
// any waiter.
void ChannelImpl::process_one_task(Command& icmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    // TODO: remove std::visit for support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, Put>) {
            // Materialize a host value into the destination tensor.
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
                                      : Tensor::make(cmd.value);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            produce_tensor(cmd.dest, std::move(value));
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, ApplyOp>) {
            // Invalid input poisons all outputs and skips execution.
            for (auto& i : cmd.inputs) {
                if (mgb_unlikely(i->invalid)) {
                    MGB_LOCK_GUARD(m_mutex);
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
                    }
                    return;
                }
            }
            if (state.options.enable_dtr_auto_drop) {
                // Route through the apply stack so dropped inputs can be
                // regenerated; then give each output its own DSU node.
                m_apply_stack.push({cmd, 0, nullptr, "cmd"});
                flush_apply_stack();
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
                    output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                }
            } else {
                do_apply_op(cmd, "cmd");
            }
            if (state.options.enable_drop && state.options.record_computing_path) {
                auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                    auto& input = std::get<0>(tuple2);
                    auto& output = std::get<1>(tuple2);
                    if (!input->ptr || !output->ptr) {
                        return false;
                    }
                    return input->ptr->blob()->storage() ==
                           output->ptr->blob()->storage();
                };
                // FIXME: do not use opname as identifier
                auto get_name = [](const OpDef& opdef) {
                    if (auto attr = opdef.try_cast_final<OprAttr>()) {
                        return attr->type.c_str();
                    }
                    return opdef.dyn_typeinfo()->name;
                };

                auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
                    return info->desc.comp_node != comp_node;
                };

                bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                bool inplace =
                        any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                // Only ops fully on the DTR device, with no aliasing and not
                // blacklisted, get a recorded ComputePath for recomputation.
                if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                    TensorInfo::ComputePath::make(
                            cmd.id, cmd.op, cmd.inputs, cmd.outputs);
                    size_t detach_cnt = 0;
                    if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
                        cmd.outputs.size() == 6) {
                        cmd.outputs[0]->detach_producer();  // detach running_mean
                        cmd.outputs[1]->detach_producer();  // detach running_var
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= 2;
                        }
                    }
                    // Small outputs are cheaper to keep than to recompute.
                    for (auto output : cmd.outputs) {
                        if (output->producer &&
                            !output->size_exceeds_thd(
                                    state.options.dtr_evictee_minimum_size)) {
                            output->detach_producer();
                            detach_cnt++;
                        }
                    }
                    for (auto input : cmd.inputs) {
                        input->ref_cnt -= detach_cnt;
                    }
                }
            }
        } else if constexpr (std::is_same_v<T, Del>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
            // Capture id/device before free() invalidates cmd.dest.
            CompNode device = cmd.dest->desc.comp_node;
            uint64_t tensor_id = cmd.dest->id;
            free(cmd.dest);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
            sample_on_device(device, false);
        } else if constexpr (std::is_same_v<T, GetValue>) {
            if (cmd.dest->invalid)
                return;
            imperative_log_profile_begin("GetValue");
            // Rebuild a dropped tensor before fetching its host value.
            if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                regenerate(cmd.dest);
            }
            cmd.dest->ptr->fetch_value();
            MGB_LOCK_GUARD(m_mutex);
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
            do_drop(cmd.dest, true);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
        } else if constexpr (std::is_same_v<T, SetOption>) {
            options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            MGB_RECORD_EVENT(StartProfileEvent);
            CompNode::sync_all();
            // Declare all live tensors so the profile has complete history.
            for (auto* info : cmd.capture_tensors) {
                MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                if (info->status == TensorInfo::Produced) {
                    // TODO: handle drop
                    MGB_RECORD_EVENT(
                            TensorProduceEvent, info->id, info->desc.layout,
                            info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                }
            }
            CompNode::foreach ([&](CompNode device) {
                sample_on_device(device, true);
                MGB_RECORD_EVENT_IF(
                        (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                        Timer::record_device(device));
            });
            MGB_RECORD_EVENT(StartProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
            MGB_RECORD_EVENT(StopProfileEvent);
            // Close out tensors that outlive the profiling session.
            for (auto* info : cmd.escape_tensors) {
                bool has_value = info->status == TensorInfo::Produced;
                if (has_value) {
                    MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
                }
                MGB_RECORD_EVENT(TensorEraseEvent, info->id);
            }
            CompNode::foreach (
                    [&](CompNode device) { sample_on_device(device, true); });
            MGB_RECORD_EVENT(StopProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, PushScope>) {
            MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
            MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
        } else {
            // Compile-time guard: every Command alternative must be handled.
            static_assert(!std::is_same_v<T, T>);
        }
    };
    std::visit(
            [&](const auto& cmd) {
                using T = std::decay_t<decltype(cmd)>;
                if (!options.catch_worker_execption) {
                    cmd_visitor(cmd);
                    return;
                }
                try {
                    cmd_visitor(cmd);
                } catch (...) {
                    MGB_LOCK_GUARD(m_mutex);
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        for (auto oup : cmd.outputs) {
                            oup->invalid = true;
                        }
                    } else if constexpr (std::is_same_v<T, Put>) {
                        cmd.dest->invalid = true;
                    }
                    // Hand the exception to the channel thread (rethrown in
                    // check_worker_exc_unsafe) and wake any waiter.
                    m_worker_exc = std::current_exception();
                    MGB_RECORD_EVENT(WorkerExceptionEvent);
                    if (m_waitee) {
                        notify_tensor_unsafe(m_waitee);
                    }
                }
            },
            icmd.data);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
1300 1301
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
1302 1303
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
1304 1305 1306 1307 1308
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
1309 1310
    }
}
1311

1312
void ChannelImpl::start_profile() {
1313
    MGB_LOCK_GUARD(m_spin);
1314
    mgb_assert(check_available(), "Channel already closed");
1315 1316
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
1317 1318 1319 1320 1321 1322 1323 1324 1325 1326
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StartProfile{std::move(capture_tensors)},
            });
        }
1327
    }
1328 1329
}

1330
void ChannelImpl::stop_profile() {
1331
    MGB_LOCK_GUARD(m_spin);
1332
    mgb_assert(check_available(), "Channel already closed");
1333 1334
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
1335 1336 1337 1338 1339 1340 1341 1342 1343 1344
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StopProfile{std::move(escape_tensors)},
            });
        }
1345
    }
1346 1347 1348
}

// Enter a named profiling scope on the channel thread and mirror it on the
// worker thread via a PushScope command.
void ChannelImpl::push_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& channel_state = get_channel_state();
    channel_state.stack_manager.enter(name);
    MGB_RECORD_EVENT(ScopeEvent, name);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PushScope{name},
                 channel_state.stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PushScope{name},
        });
    }
}

// Leave a named profiling scope on the channel thread and mirror it on the
// worker thread via a PopScope command.
void ChannelImpl::pop_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& channel_state = get_channel_state();
    channel_state.stack_manager.exit(name);
    MGB_RECORD_EVENT(ScopeFinishEvent, name);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PopScope{name},
                 channel_state.stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PopScope{name},
        });
    }
}

1384
void ChannelImpl::assert_in_channel() {
M
Megvii Engine Team 已提交
1385 1386 1387
    mgb_assert(
            get_worker_tid() != std::this_thread::get_id(),
            "this method cannot be called in worker thread");
1388 1389 1390
}

void ChannelImpl::assert_in_worker() {
M
Megvii Engine Team 已提交
1391 1392 1393
    mgb_assert(
            get_worker_tid() == std::this_thread::get_id(),
            "this method can only be called in worker thread");
1394 1395
}

1396
// Record a memory-usage sample for `device` into the profile. When not
// forced, sampling is throttled per thread to every `sample_rate`-th call.
void ChannelImpl::sample_on_device(CompNode device, bool force) {
    if (!Profiler::is_profiling()) {
        return;
    }
    if (!force) {
        thread_local int last_sample_id = 0;
        int sample_rate = Profiler::get_option("sample_rate", 0);
        // Short-circuit keeps the counter untouched when sampling is off.
        bool due = sample_rate && ((++last_sample_id) % sample_rate == 0);
        if (!due) {
            return;
        }
    }
    MGB_RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}

1412 1413 1414
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
1415
        erase_candidate(i);
1416 1417 1418
    }
}

1419 1420
void ChannelImpl::DynamicSublinear::unpin(
        const SmallVector<TensorInfo*>& vec, WorkerState& state) {
1421 1422
    for (auto i : vec) {
        i->unpin();
1423 1424 1425 1426 1427
        if (i->pinned == 0 &&
            i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
            i->cand_index == UINT_MAX) {
            insert_candidate(i);
        }
1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473
    }
}

// After `ptr` is recomputed: deduct its cost from its DSU component's
// total, then split it back out as a standalone singleton set.
void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& component_root = find_father(ptr->dsu_ptr);
    component_root->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

// After `ptr` is evicted: union its DSU node with every already-dropped
// producer input and output, so each component tracks the aggregate
// recomputation cost of a connected dropped region.
void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    auto&& path = ptr->producer;
    for (auto* in : path->inputs) {
        if (in->evict_type == EvictType::DROP) {
            merge(in->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto* out : path->outputs) {
        if (out && out->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, out->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    // Estimate the recompute cost incurred by evicting `ptr`: for each
    // dropped neighbor (input or output of the producer op) add the cost of
    // its whole DSU set, clamped to at least its own compute time.
    double total = 0;
    auto accumulate = [&](TensorInfo* neighbor) {
        if (neighbor->evict_type != EvictType::DROP)
            return;
        double t = find_father(neighbor->dsu_ptr)->t;
        if (t < neighbor->compute_time) {
            t = neighbor->compute_time;
        }
        total += t;
    };
    for (auto* in : ptr->producer->inputs) {
        accumulate(in);
    }
    for (auto* out : ptr->producer->outputs) {
        // outputs may contain null slots; inputs never do
        if (out != nullptr) {
            accumulate(out);
        }
    }
    return total;
}

M
Megvii Engine Team 已提交
1474 1475
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
        bool enable_dtr_sqrt_sampling = false) {
1476 1477 1478
    if (candidates.empty())
        return nullptr;

1479 1480
    double min_msps = -1;
    TensorInfo* best = nullptr;
1481 1482
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
M
Megvii Engine Team 已提交
1483 1484
        while (sz * sz <= candidates.size())
            sz++;
1485
        sz--;
1486 1487 1488
    } else {
        sz = candidates.size();
    }
1489 1490 1491 1492 1493 1494 1495

    size_t ti = rand() % sz;
    for (size_t vi = 0; vi < sz; vi++) {
        if (!enable_dtr_sqrt_sampling) {
            ti = vi;
        }
        auto i = candidates[ti];
1496
        if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
1497
            double neighbor_cost = estimate_neighbor_cost(i);
M
Megvii Engine Team 已提交
1498 1499 1500 1501
            size_t begin_ptr =
                    reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(
                    begin_ptr, begin_ptr + i->ptr->blob()->size());
1502
            double free_mem = side_info.first + side_info.second;
M
Megvii Engine Team 已提交
1503 1504
            double msps = i->eval_func(
                    neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
1505 1506 1507 1508 1509
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
1510 1511 1512 1513 1514
        if (enable_dtr_sqrt_sampling) {
            ti += rand() % sz;
            if (ti > candidates.size())
                break;
        }
1515 1516 1517 1518
    }
    return best;
}

M
Megvii Engine Team 已提交
1519 1520
void ChannelImpl::DynamicSublinear::merge(
        std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
1521 1522 1523 1524 1525 1526 1527 1528 1529
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

M
Megvii Engine Team 已提交
1530 1531
std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
        std::shared_ptr<DsuNode>& x) {
1532 1533 1534 1535 1536 1537 1538 1539 1540
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    // tensor to be inserted must be brand new
    mgb_assert(
            ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
            ptr->cand_index);
    // append at the tail; cand_index records the slot so erase_candidate can
    // remove it in O(1) via swap-with-back
    ptr->cand_index = candidates.size();
    candidates.push_back(ptr);
    // lazily record the comp node this policy instance operates on
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    // close dtr will just clear candidates, so nothing to erase
    if (candidates.empty()) {
        ptr->cand_index = UINT_MAX;
        return;
    }
    // some tensors may be erased already, just skip them
    if (ptr->cand_index == UINT_MAX) {
        return;
    }
    // O(1) removal: move the last candidate into this tensor's slot, fix its
    // recorded index, then drop the tail
    size_t slot = ptr->cand_index;
    std::swap(candidates[slot], candidates.back());
    candidates[slot]->cand_index = slot;
    candidates.pop_back();
    ptr->cand_index = UINT_MAX;
}

void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    // stamp the tensor with the current estimated time for staleness scoring
    ptr->last_used_time = estimate_timestamp;
}