/**
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./interpreter_impl.h"

#include "range/v3/all.hpp"

#include "megbrain/common.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/to_string.h"

#include "../blob_manager_impl.h"
#include "../event_pool.h"
#include "../op_trait.h"

using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

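// RECORD_EVENT forwards a profiler event of the given type, but only when
// profiling is active, so it costs next to nothing on the non-profiling path.
// Illustrative use (CustomEvent carries a std::string payload, as in
// imperative_log_profile_begin below):
//     RECORD_EVENT(CustomEvent, std::string{"my-scope"});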
#define RECORD_EVENT(type, ...) \
    if (Profiler::is_profiling()) { \
        Profiler::record<type>(type{__VA_ARGS__}); \
    } \


namespace {
    auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
        SmallVector<uint64_t> tid;
        for (auto* ptinfo: tinfo) {
            tid.push_back(ptinfo->id);
        }
        return tid;
    };
}

namespace mgb {
    using namespace profiler;
}

#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif

namespace mgb {

/**
 * USAGE
 *
 *   header:
 *     namespace mgb { void imperative_log_profile(const char* message); }
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    RECORD_EVENT(CustomEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    RECORD_EVENT(CustomFinishEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile(const char* message) {
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

}

std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
    OpDef::set_allocator([&](CompNode device, size_t size) {
        auto blob = Blob::make(device, size);
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
    });
}

// Do not use m_xxx_state directly
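// The empty defines below intentionally shadow the member names, so the rest
// of this file has to go through get_channel_state()/get_worker_state(),
// which assert that they are called from the expected thread.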
#define m_channel_state
#define m_worker_state

std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

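// Wrap a host value into a new TensorInfo and enqueue an asynchronous Put
// command; when m_async_level == 0 the call synchronizes before returning.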
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.push("Put");
    auto info = put_impl(value, no_cache);
    state.scopes.pop("Put");
    return info;
}

TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    if (value.empty()) {
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto info = alloc();
    init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
    info->h_value = value;
    m_buffer.enqueue(Put{info, value, no_cache});
    if (m_async_level == 0) {
        sync_impl();
        info->desc.comp_node.sync();
    }
    return info;
}

Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return put_impl(data, hvalue);
}
TensorInfo* ChannelImpl::put_impl(const DeviceTensorND& data, const HostTensorND& hvalue) {
    auto& state = get_channel_state();
    state.scopes.push("Put");
    auto info = alloc();
    RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
    init(info, {data.layout(), data.comp_node()});
    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
    info->ptr = Tensor::make(data, hvalue);
    RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
    info->status = TensorInfo::Produced;
    RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put);
    state.scopes.pop("Put");
    return info;
}

void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()){
        return;
    }
    del_impl(handle);
}

void ChannelImpl::del_impl(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    m_buffer.enqueue(Del{info});
}

void ChannelImpl::swap_in(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapIn{info});
    }
}

void ChannelImpl::swap_out(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapOut{info});
    }
}

void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(Drop{info});
    }
}

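// Host-compute fast path: when every input value is available on the host
// (either already fetched or stored in h_value), the op is executed
// immediately on default-CPU proxies of those values and the results are
// registered through put_impl, bypassing the worker queue entirely.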
void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();

    auto name = op->trait()->make_name(*op);
    state.scopes.push(name);

    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    RECORD_EVENT(ShapeInferEvent, validated);

    SmallVector<DeviceTensorND> input_tensornds;
    input_tensornds.reserve(input_descs.size());
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
            } else {
                // It's OK for SwapOut: h_value is assigned before ptr is dropped
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    outputs->reserve(output_descs.size());
    SmallVector<DeviceTensorND> output_tensornds;
    output_tensornds.reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which needs alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    uint64_t op_id = Profiler::next_id();

    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);

    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
            .proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(info);
    }
    auto op_info_getter = [op]{
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value]: props) {
            op_info[key] = value;
        }
        return op_info;
    };
    RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));

    state.scopes.pop(name);
}

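// Device path: allocate a TensorInfo for each output inferred by
// infer_output_attrs_fallible and enqueue an ApplyOp command. With
// async_level == 1 the channel synchronizes only when shape inference was not
// validated; with async_level == 0 it also waits for the output devices.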
void ChannelImpl::dispatch_kernel(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();
    auto& options = state.options;

    auto name = op->trait()->make_name(*op);
    state.scopes.push(name);

    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    RECORD_EVENT(ShapeInferEvent, validated);

    ApplyOp cmd{Profiler::next_id(), std::move(op)};
    cmd.inputs = std::move(input_infos);
    cmd.outputs.reserve(output_descs.size());
    outputs->reserve(output_descs.size());
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
        auto info = alloc();
        init(info, desc);
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
                .proxy_to_comp_node(desc.comp_node);
        }
        cmd.outputs.push_back(info);
        outputs->push_back(info);
    }
    auto op_info_getter = [op=cmd.op]{
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value]: props) {
            op_info[key] = value;
        }
        return op_info;
    };
    RECORD_EVENT(OpDispatchEvent, cmd.id, cmd.op->trait()->name, op_info_getter, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
    m_buffer.enqueue(std::move(cmd));
    if (!validated && options.async_level == 1) {
        sync_impl();
    } else if (options.async_level == 0) {
        sync_impl();
        // check device error
        for (auto&& oup : *outputs) {
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
        }
    }
    state.scopes.pop(name);
}

SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op,
        const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return apply_op_impl(std::move(op), inputs);
}

SmallVector<Handle> ChannelImpl::apply_op_impl(
        std::shared_ptr<OpDef> op,
        const SmallVector<Handle>& inputs) {
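    // Snapshot the input descriptors under m_mutex, then pick the dispatch
    // path: DEFAULT_CPU (host compute) is only considered when
    // enable_host_compute is set, otherwise the op goes through KERNEL.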
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                "invalid handle: %p", i);
    }
    SmallVector<TensorInfo*> input_infos;
    input_infos.reserve(inputs.size());
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.reserve(inputs.size());
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(!info->invalid, "an input tensor is unusable due to previous error");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    DispatchMode dispatch_mode = state.options.enable_host_compute
            ? OpDef::decide_dispatch_mode(*op, input_descs)
            : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

void ChannelImpl::sync() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    sync_impl();
}

void ChannelImpl::sync_impl() {
    m_buffer.flush();
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
    for (auto* handle: valid_handles) {
        del_impl(handle);
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
    sync_impl();
    m_closed = true;
}

size_t ChannelImpl::get_option(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    return state.options.get_option(name);
}

void ChannelImpl::set_option(std::string name, size_t value) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.options.set_option(name, value);
    m_buffer.enqueue(SetOption{name, value});
}

TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this]{
        MGB_LOCK_GUARD(m_mutex);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        info->name = state.scopes.next_tensor_name();
    }
    return info;
}

void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
    m_valid_handle.insert(info);
    RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    info->desc = std::move(desc);
    info->mem_desc.layout = info->desc.layout;
    info->mem_desc.cn = info->desc.comp_node;
    info->mem_desc.offset = 0;
}


void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
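    // Mark the tensor as dropped and release its device storage; it can be
    // recomputed later from its producer ComputePath (see regenerate()).
    // Tensors without a producer cannot be regenerated and are left untouched.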
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

void ChannelImpl::recursive_free(TensorInfo* ptr) {
    RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandEvent::RecFree);
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandFinishEvent::RecFree);
}

void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    MGB_LOCK_GUARD(m_mutex);
    m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}

ChannelImpl::~ChannelImpl() {
    close();
}

void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
    // update tensor desc for static infer
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

void ChannelImpl::release_tensor(TensorInfo* dest) {
    RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
}

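// Recover an evicted tensor: a DROP-evicted tensor is recomputed by pushing
// its producer ComputePath onto the apply stack, while a SWAP-evicted tensor
// is reloaded from its host copy.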
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        auto &&path = dest->producer;
        m_apply_stack.push({ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest});
        if (!m_applying) flush_apply_stack();
    } else if (dest->evict_type == EvictType::SWAP) {
        RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
        produce_tensor(dest, Tensor::make(dest->h_value));
        RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
    }
}

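// Execute one ApplyOp on the worker thread: optionally auto-evict under DTR,
// apply the op (recursing through make_forward_graph when it expands into a
// subgraph), emit profiling events, produce the outputs, and update the DTR
// bookkeeping (estimated compute time, unpinning of inputs).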
void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    bool profiling_device = Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
    uint64_t apply_id = cmd.id;
    struct TensorWithDesc {
        TensorPtr tensor;
        MemoryDesc desc;
    };
    SmallVector<TensorWithDesc> inputs;
    inputs.reserve(cmd.inputs.size());
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
        // refcnt ++, owners: [i->ptr, tensor_inputs]
        // tensor_inputs.push_back(i->ptr);
        inputs.push_back({i->ptr, i->mem_desc});
    }
    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
        auto_evict(0);
    }
    auto apply_on_physical_tensor = [&](auto&& self, const OpDef& def, SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
        auto apply_functor = [&](std::shared_ptr<OpDef> op, SmallVector<TensorWithDesc> inputs, size_t nr_outputs) -> SmallVector<TensorWithDesc> {
            auto opname = op->trait()->make_name(*op);
            imperative_log_profile_begin(opname.c_str());
            auto outputs = self(self, *op, inputs);
            imperative_log_profile_end(opname.c_str());
            return outputs;
        };
        auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
            return {value, MemoryDesc{value->layout(), 0, value->comp_node(), StorageIdentifier::make()}};
        };
        if (def.trait()->make_forward_graph) {
            // apply recursively
            SmallVector<LogicalTensorDesc> input_descs;
            for (auto&& input: inputs) {
                input_descs.push_back({{{}, input.tensor->dtype()}, input.tensor->comp_node()});
            }
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
            auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
            return outputs;
        }
        SmallVector<TensorPtr> input_tensors;
        SmallVector<MemoryDesc> input_descs;
        for (auto&& input: inputs) {
            input_tensors.push_back(input.tensor);
            input_descs.push_back(input.desc);
        }
        auto [output_descs, output_tensors, workspaces] = init_output_and_workspace(def, input_tensors, input_descs);
        if (!output_descs.empty()) {
            OpDef::execute(def, input_tensors, output_tensors, workspaces);
        } else {
            output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors);
            for (auto&& output_tensor: output_tensors) {
                output_descs.push_back(MemoryDesc{output_tensor->layout(), 0, output_tensor->comp_node(), StorageIdentifier::make()});
            }
        }
        SmallVector<TensorWithDesc> outputs;
        for (auto&& [output_tensor, output_desc]: ranges::zip_view(output_tensors, output_descs)) {
            outputs.push_back({output_tensor, output_desc});
        }
        return outputs;
    };
    RECORD_EVENT(OpExecuteEvent, apply_id);
    // Begin profiling operator
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
            }
        }
    }
    for (auto* input: cmd.inputs) {
        auto input_id = input->id;
        RECORD_EVENT(OpInputEvent, input_id);
        RECORD_EVENT(TensorUsageEvent, input_id);
        RECORD_EVENT(OpInputFinishEvent, input_id);
    }
    // Fused by command buffer. @see: CommandBuffer::fuse_del
    // Now if dest is inplace-capable, its refcnt would be decreased to 1 and owned by tensor_inputs after Del.
    // Note: for exprs like 'y = x op x', inplace is not supported yet, but the Del would still be fused.
    for (auto* del : cmd.dels) {
        // refcnt --, owners: [tensor_inputs]
        // if it drops to 1, that will be detected at @see: proxy_graph_detail::apply_on_physical_tensor
        uint64_t del_id = del->id;
        RECORD_EVENT(OpDelEvent, del_id);
        free(del);
        RECORD_EVENT(OpDelFinishEvent, del_id);
    }
    // Before wait
    // TODO: split operator wait and execute so that OpWait could be correctly recorded.
    // Before execute
    for (auto&& [device, kernel_id]: kernels) {
        RECORD_EVENT(KernelExecuteEvent, apply_id, kernel_id, Timer::record_event(device));
733 734 735
    }
    // Apply op
    // Here std::move is REQUIRED for removing duplicated references.
736
    auto outputs = apply_on_physical_tensor(apply_on_physical_tensor, *cmd.op, inputs);
737
    // After execute
738 739
    for (auto&& [device, kernel_id]: kernels) {
        RECORD_EVENT(KernelExecuteFinishEvent, apply_id, kernel_id, Timer::record_event(device));
740 741
    }
    // End profiling operator
742 743
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
744
        auto output = cmd.outputs[i];
745 746 747 748 749 750 751 752
        if (output == nullptr) {
            RECORD_EVENT(OpOutputEvent, 0);
            RECORD_EVENT(OpOutputFinishEvent, 0);
        } else if (output->ptr != nullptr) {
            RECORD_EVENT(OpOutputEvent, output->id);
            RECORD_EVENT(OpOutputFinishEvent, output->id);
        } else {
            RECORD_EVENT(OpOutputEvent, output->id);
753 754
            produce_tensor(output, outputs[i].tensor);
            output->mem_desc = outputs[i].desc;
755 756
            RECORD_EVENT(OpOutputFinishEvent, output->id);
            sample_on_device(output->desc.comp_node, false);
757 758 759 760 761 762 763 764
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
765 766
        for (auto i : outputs) {
            estimate_compute_time += i.tensor->blob()->size();
767 768 769 770 771 772 773 774 775
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
        m_dtr.unpin(cmd.inputs);
    }
776 777
    RECORD_EVENT(OpExecuteFinishEvent, apply_id);
    // End profiling operator
778
}
779

780 781
void ChannelImpl::flush_apply_stack() {
    m_applying = true;
782
    auto& state = get_worker_state();
783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815
    while (!m_apply_stack.empty()) {
        auto& [cmd, idx, recomp] = m_apply_stack.top(); // cmd.inputs[0~idx-1] is in memory
        if (idx == 0) {
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
                RECORD_EVENT(TensorCommandEvent, recomp->id, TensorCommandEvent::ReGen);
            }
        }
        bool regen = false;
        for (size_t i = idx; i < cmd.inputs.size(); i ++) {
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                idx = i + 1;
                regenerate(p); // add ApplyOp to the stack
                regen = true;
                break;
            }
        }
        if (regen) continue;
        // the required input tensors are already in memory
        auto cmd_backup = cmd;
        auto recomp_backup = recomp;
        m_apply_stack.pop();
        do_apply_op(cmd_backup);
        if (recomp_backup) {
            RECORD_EVENT(TensorCommandFinishEvent, recomp_backup->id, TensorCommandFinishEvent::ReGen);
            for (auto o : cmd_backup.outputs) {
                if (o) {
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
    }
    m_applying = false;
}

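// Evict DTR candidates until used memory falls below dtr_eviction_threshold or
// force_num tensors have been evicted; victims are chosen by find_best_tensor.
// Returns whether any memory was actually reclaimed.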
bool ChannelImpl::auto_evict(size_t force_num) {
    auto& state = get_worker_state();
    if (!m_dtr.comp_node.valid()) {
        return false;
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
    bool flag = false;
    while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) {
        RECORD_EVENT(AutoEvictEvent);
        sample_on_device(m_dtr.comp_node, false);
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling && !force_num);
        if (!best) {
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
            if (force_num > 0) {
                force_num --;
            }
            flag = true;
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
        }
        sample_on_device(m_dtr.comp_node, false);
        RECORD_EVENT(AutoEvictFinishEvent);
    }
    return flag;
}

void ChannelImpl::detach_users(TensorInfo* dest) {
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user: users) {
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output: outputs) {
            // When a `ComputePath` is detached from its input,
            // there is no need to preserve it,
            // so we detach all outputs of this path
            // to decrease their `ref_cnt` to zero.
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
            for (auto* input: inputs) {
                input->ref_cnt --;
            }
        }
        // now user is dead
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}

bool ChannelImpl::check_available() {
    return !m_closed;
}

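// Block the calling (channel) thread until the worker has produced the
// requested property of `info`; for HostValue a GetValue command is enqueued
// first and the wait completes once the value has been fetched to the host.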
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    m_buffer.flush();
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool require_host = prop == TensorProp::HostValue;
    auto host_available = [&]{
        return info->ptr && info->ptr->value_fetched();
    };
    if (require_host && !host_available()) {
        // avoid dead lock
        lock.unlock();
        m_buffer.enqueue(GetValue{info});
        m_buffer.flush();
        lock.lock();
    }
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return require_host ? host_available() : static_cast<bool>(info->ptr);
    });
    RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr);
    m_waitee = nullptr;
    return info->ptr;
}

void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
    }
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
    for (auto* handle: m_valid_handle) {
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
    }
    return valid_tensors;
}

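// Allocate device storage for a blob, evicting one DTR candidate at a time and
// retrying on MemAllocError; as a last resort, defragment the device through
// the blob manager and try once more.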
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
    auto reserve_size = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            bool evict_suc = auto_evict(1);
            if (!evict_suc) return false;
        }
        return true;
    };
    auto pre_level = set_log_level(LogLevel::NO_LOG);
    reserve_size(x->size());
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
    MGB_CATCH(MemAllocError&, {
        bool suc = false;
        while (!suc) {
            if (!auto_evict(1)) {
                break;
            }
            MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
            MGB_CATCH(MemAllocError&, { continue; });
            suc = true;
        }
        if (!suc) {
            set_log_level(pre_level);
            mgb_log_warn("reallocating all cuda memory to alleviate fragmentation, the performance may be affected");
            set_log_level(LogLevel::NO_LOG);
            BlobManager::inst()->defrag(x->comp_node());
            BlobManager::inst()->alloc_direct(x, x->size());
        }
    });
    set_log_level(pre_level);
}

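// Infer output and workspace memory descriptors for `def`, give sys-allocated
// outputs fresh storage ids, and materialize the storage: fresh allocations go
// through alloc_tensor_with_evict under DTR, while descriptors marked as
// coming from an input reuse a sub-tensor of that input.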
std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> ChannelImpl::init_output_and_workspace(
        const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<MemoryDesc> inputs_mem_desc) {

    auto [outputs_desc, workspaces_desc] = OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
    if (!outputs_desc.size()) {
        // failed to infer memplan
        return {{}, {}, {}};
    }
    // refine storage id to make it unique
    for (auto&& desc : outputs_desc) {
        if (desc.id->is_sys_alloc()) {
            // TODO: there may be some outputs sharing the same storage id
            desc.id->id = ++ m_storage_id;
        }
    }
    auto& state = get_worker_state();
    auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
        SmallVector<TensorPtr> tensors;
        for (size_t i = 0; i < desc.size(); i ++) {
            if (desc[i].id->is_sys_alloc()) {
                tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
                if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) {
                    alloc_tensor_with_evict(tensors.back()->blob().get());
                }
            } else if (desc[i].id->is_from_other()) {
                for (size_t j = 0; j < inputs_mem_desc.size();j ++) {
                    if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
                        tensors.push_back(inputs[j]->sub(desc[i].offset, desc[i].layout));
                        break;
                    }
                }
            } else if (desc[i].id->is_device_ptr()) {
                tensors.push_back(desc[i].id->ptr);
            } else {
                mgb_assert(0, "not implemented");
            }
        }
        return tensors;
    };

    return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
}

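// Worker-side command dispatcher: executes Put/ApplyOp/Del/GetValue/SwapIn/
// SwapOut/Drop/SetOption and the profiling/scope commands. When
// catch_worker_execption is enabled, exceptions are recorded here and later
// surfaced on the channel thread by check_worker_exc_unsafe().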
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    // TODO: remove std::visit to support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, Put>) {
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Put);
                auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
                produce_tensor(cmd.dest, std::move(value));
                RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Put);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, ApplyOp>) {
                for (auto& i : cmd.inputs) {
                    if (i->invalid) {
                        MGB_LOCK_GUARD(m_mutex);
                        for (auto& i : cmd.outputs) {
                            i->invalid = true;
                        }
                        return;
                    }
                }
                m_apply_stack.push({cmd, 0, nullptr});
                flush_apply_stack();
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
                    if (state.options.enable_dtr_auto_drop) {
                        output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                    }
                }
                if (state.options.enable_drop && state.options.record_computing_path) {
                    auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                        auto& input = std::get<0>(tuple2);
                        auto& output = std::get<1>(tuple2);
                        if (!input->ptr || !output->ptr) {
                            return false;
                        }
                        return input->ptr->blob()->storage() == output->ptr->blob()->storage();
                    };
                    // FIXME: do not use opname as identifier
                    auto get_name = [](const OpDef& opdef) {
                        if (auto attr = opdef.try_cast_final<OprAttr>()) {
                            return attr->type.c_str();
                        }
                        return opdef.dyn_typeinfo()->name;
                    };

                    auto is_cross_cn = [comp_node=m_dtr.comp_node](TensorInfo* info){
                        return info->desc.comp_node != comp_node;
                    };

                    bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                    bool inplace = any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                    if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                        TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs);
                        size_t detach_cnt = 0;
                        if (!strcmp(get_name(*cmd.op), "BatchNorm") && cmd.outputs.size() == 5) {
                            cmd.outputs[0]->detach_producer(); // detach running_mean
                            cmd.outputs[1]->detach_producer(); // detach running_var
                            for (auto input : cmd.inputs) {
                                input->ref_cnt -= 2;
                            }
                        }
                        for (auto output : cmd.outputs) {
                            if (output->producer && !output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
                                output->detach_producer();
                                detach_cnt ++;
                            }
                        }
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= detach_cnt;
                        }
                    }
                }
            } else if constexpr (std::is_same_v<T, Del>) {
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Del);
                CompNode device = cmd.dest->desc.comp_node;
                uint64_t tensor_id = cmd.dest->id;
                free(cmd.dest);
                RECORD_EVENT(TensorCommandFinishEvent, tensor_id, TensorCommandFinishEvent::Del);
                sample_on_device(device, false);
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (cmd.dest->invalid) return;
                imperative_log_profile_begin("GetValue");
                if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                    regenerate(cmd.dest);
                }
                cmd.dest->ptr->fetch_value();
                MGB_LOCK_GUARD(m_mutex);
                notify_tensor_unsafe(cmd.dest);
                imperative_log_profile_end("GetValue");
            } else if constexpr (std::is_same_v<T, SwapIn>) {
                if (cmd.dest->invalid) return;
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapIn);
                produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
                RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapIn);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, SwapOut>) {
                if (cmd.dest->invalid) return;
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapOut);
                cmd.dest->h_value = cmd.dest->ptr->get_value();
                if (cmd.dest->evict_type == EvictType::NONE) {
                    cmd.dest->evict_type = EvictType::SWAP;
                    cmd.dest->status = TensorInfo::Swapped;
                    release_tensor(cmd.dest);
                }
                RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapOut);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, Drop>) {
                if (cmd.dest->invalid) return;
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Drop);
                do_drop(cmd.dest, true);
                RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Drop);
            } else if constexpr (std::is_same_v<T, SetOption>) {
                options.set_option(cmd.key, cmd.value);
            } else if constexpr (std::is_same_v<T, StartProfile>) {
                RECORD_EVENT(StartProfileEvent);
                CompNode::sync_all();
                for (auto* info: cmd.capture_tensors) {
                    RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                    if (info->status == TensorInfo::Produced) {
                        // TODO: handle swap/drop
                        RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                    }
                }
                CompNode::foreach([&](CompNode device){
                    if (Profiler::get_option("sample_rate", 0)) {
                        sample_on_device(device, true);
                    }
                });
                RECORD_EVENT(StartProfileFinishEvent);
            } else if constexpr (std::is_same_v<T, StopProfile>) {
                RECORD_EVENT(StopProfileEvent);
                for (auto* info: cmd.escape_tensors) {
                    bool has_value = info->status == TensorInfo::Produced;
                    if (has_value) {
                        RECORD_EVENT(TensorReleaseEvent, info->id);
                    }
                    RECORD_EVENT(TensorEraseEvent, info->id);
                }
                CompNode::foreach([&](CompNode device){
                    if (Profiler::get_option("sample_rate", 0)) {
                        sample_on_device(device, true);
                    }
                });
                RECORD_EVENT(StopProfileFinishEvent);
            } else if constexpr (std::is_same_v<T, PushScope>) {
                RECORD_EVENT(ScopeEvent, cmd.scope_name);
            } else if constexpr (std::is_same_v<T, PopScope>) {
                RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
            } else {
                static_assert(!std::is_same_v<T, T>);
            }
    };
    std::visit([&](const auto& cmd){
        using T = std::decay_t<decltype(cmd)>;
        if (!options.catch_worker_execption) {
            cmd_visitor(cmd);
            return;
        }
        try {
            cmd_visitor(cmd);
        } catch (...) {
            MGB_LOCK_GUARD(m_mutex);
            if constexpr (std::is_same_v<T, ApplyOp>) {
                for (auto oup : cmd.outputs) {
                    oup->invalid = true;
                }
            } else if constexpr (std::is_same_v<T, Put>) {
                cmd.dest->invalid = true;
            }
            m_worker_exc = std::current_exception();
            RECORD_EVENT(WorkerExceptionEvent);
            if (m_waitee) {
                notify_tensor_unsafe(m_waitee);
            }
        }
    }, icmd.second);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
    }
}

void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
        return;
    }
    // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
    m_commands.push_back(std::move(cmd));
    auto flush_pos = flush_pos_for(m_commands.back());
    flush(flush_pos);
}

void ChannelImpl::CommandBuffer::flush() {
    flush(m_commands.end());
}

void ChannelImpl::CommandBuffer::flush(Handle pos) {
    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
        if (Profiler::is_profiling()) {
            mgb_log_debug("%s Flushed", to_string(*iter).c_str());
        }
        m_owner->m_worker.add_task(IdentifiedCommand{Profiler::next_id(), std::move(*iter)});
    }
    m_commands.erase(m_commands.begin(), pos);
}

auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
    auto& state = m_owner->get_channel_state();
    return std::visit([this, &state](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, ApplyOp>) {
            auto* op_type = cmd.op->dyn_typeinfo();
            if (op_type == RemoteRecv::typeinfo() ||
                op_type == RemoteSend::typeinfo() ||
                op_type == CollectiveComm::typeinfo() ||
                op_type == opr::InputCallback::typeinfo() ||
                op_type == opr::OutputCallback::typeinfo()) {
                return m_commands.end();
            }
        } else if constexpr (std::is_same_v<T, GetValue>) {
            return m_commands.end();
        }
        size_t buffer_length = state.options.buffer_length;
        if (m_commands.size() > buffer_length) {
            return m_commands.begin() + (m_commands.size() - buffer_length);
        }
        return m_commands.begin();
    }, cmd);
}

/**
 * 1. Find ApplyOp(dest) in buffered commands
 * 2. Check if there are other usages between ApplyOp and Del, return false if not
 * 3. Fuse Del into ApplyOp, return true
 */
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
    auto* dest = cmd.dest;
    // TODO: eliminate Puts
    auto begin = m_commands.begin(), end = m_commands.end();
    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
        if (auto* apply = std::get_if<ApplyOp>(&cmd)) {
            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
        }
        return false;
    });
    if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
        return false;
    }
    // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
    std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
    return true;
}

auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
        -> Handle {
    auto found = range[1];
    for (auto iter = range[0]; iter != range[1]; ++iter) {
        std::visit([&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
                               dest) > 0) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (cmd.dest == dest) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, SwapIn> ||
                    std::is_same_v<T, SwapOut> ||
                    std::is_same_v<T, Drop>) {
                //TODO: ignore swap-like commands, just remove them from buffer
                if (cmd.dest == dest) {
                    found = iter;
                }
            }
        }, *iter);
    };
    return found;
}

auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
        -> Handle {
    return std::find_if(range[0], range[1], [dest](auto& cmd) {
        return std::visit([dest](const auto& cmd){
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
            } else if constexpr (std::is_same_v<T, Put>) {
                return cmd.dest == dest;
            }
            return false;
        }, cmd);
    });
}

void ChannelImpl::start_profile() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
        m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
    }
}

void ChannelImpl::stop_profile() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_buffer.flush();
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
        m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
    }
}

void ChannelImpl::push_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.push(name);
    RECORD_EVENT(ScopeEvent, name);
    m_buffer.enqueue(PushScope{name});
}

void ChannelImpl::pop_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.pop(name);
    RECORD_EVENT(ScopeFinishEvent, name);
    m_buffer.enqueue(PopScope{name});
}

void ChannelImpl::assert_in_channel() {
    mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
}

void ChannelImpl::assert_in_worker() {
    mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
}

void ChannelImpl::sample_on_device(CompNode device, bool force) {
    if (!force) {
        thread_local int last_sample_id = 0;
        int sample_rate = Profiler::is_profiling() ? Profiler::get_option("sample_rate", 0) : 0;
        if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
            return;
        }
    }
    RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}

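// DTR (DynamicSublinear) bookkeeping. Dropped tensors are grouped in a
// disjoint-set (DsuNode); the root of each set accumulates the recompute time
// of the whole group, so estimate_neighbor_cost can price an eviction by the
// cost of recomputing its already-dropped neighborhood.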
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
    }
}

void ChannelImpl::DynamicSublinear::unpin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->unpin();
    }
}

void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    dsu_fa->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    double cost = 0;
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    return cost;
}

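// Pick the eviction victim with the smallest score from eval_func, which
// combines the recompute cost of the dropped neighborhood
// (estimate_neighbor_cost), the free memory adjacent to the blob, and the
// current estimated timestamp. With enable_dtr_sqrt_sampling only about
// sqrt(|candidates|) entries are scanned.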
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(bool enable_dtr_sqrt_sampling=false) {
    double min_msps = -1;
    TensorInfo* best = nullptr;
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
        while (sz * sz <= candidates.size()) sz ++;
    } else {
        sz = candidates.size();
    }
    for (auto i : candidates) {
        if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
            double neighbor_cost = estimate_neighbor_cost(i);
            size_t begin_ptr = reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(begin_ptr, begin_ptr + i->ptr->blob()->size());
            double free_mem = side_info.first + side_info.second;
            double msps = i->eval_func(neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
        if (--sz == 0) break;
    }
    return best;
}

void ChannelImpl::DynamicSublinear::merge(std::shared_ptr<DsuNode> &x, std::shared_ptr<DsuNode> &y) {
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(std::shared_ptr<DsuNode>& x) {
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    candidates.insert(ptr);
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    candidates.erase(ptr);
}

void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}