/**
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./interpreter_impl.h"

#include "range/v3/all.hpp"

#include "megbrain/common.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/to_string.h"

#include "../blob_manager_impl.h"
#include "../event_pool.h"
#include "../op_trait.h"

using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

namespace {
    auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
        SmallVector<uint64_t> tid;
        for (auto* ptinfo: tinfo) {
            tid.push_back(ptinfo->id);
        }
        return tid;
    };
}

namespace mgb {
    using namespace profiler;
}

#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif

namespace mgb {

/**
 * USAGE
 *
 *   header:
 *     namespace mgb { void imperative_log_profile(const char* message); }
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    MGB_RECORD_EVENT(CustomEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile(const char* message){
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

}
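// Example (sketch, assuming the declarations above are exposed to user code):
//     mgb::imperative_log_profile_begin("my stage");
//     /* ... run some ops ... */
//     mgb::imperative_log_profile_end("my stage");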

std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
    OpDef::set_allocator([&](CompNode device, size_t size) {
        auto blob = Blob::make(device, size);
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
    });
}

// Do not use m_xxx_state directly
#define m_channel_state
#define m_worker_state
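// The two empty #defines above poison direct access: any later mention of
// m_channel_state / m_worker_state expands to nothing and fails to compile,
// so the code below must go through get_channel_state()/get_worker_state(),
// which assert they are called from the right thread.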

std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}
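// Typical client flow (sketch, assuming a HostTensorND `hv` and an OpDef `op`
// prepared by the caller; not taken verbatim from this file):
//     auto channel = Interpreter::inst().create_channel();
//     auto* in  = channel->put(hv, /* no_cache */ false);
//     auto outs = channel->apply_op(op, {in});
//     HostTensorND result = channel->get_value(outs[0]);
//     channel->del(outs[0]);
//     channel->del(in);
// Commands are buffered and executed asynchronously on the worker thread;
// get_value() blocks until the requested tensor has been produced.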

Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    auto info = put_impl(value, no_cache);
    return info;
}

TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    if (value.empty()) {
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto info = alloc();
    init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
    info->h_value = value;
    m_buffer.enqueue(Put{info, value, no_cache});
    if (m_async_level == 0) {
        sync_impl();
        info->desc.comp_node.sync();
    }
    return info;
}

Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return put_impl(data, hvalue);
}

TensorInfo* ChannelImpl::put_impl(const DeviceTensorND& data, const HostTensorND& hvalue) {
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    auto info = alloc();
    MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
    init(info, {data.layout(), data.comp_node()});
    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
    info->ptr = Tensor::make(data, hvalue);
    MGB_RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
    info->status = TensorInfo::Produced;
    MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
    return info;
}

void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()){
        return;
    }
    del_impl(handle);
}

void ChannelImpl::del_impl(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    m_buffer.enqueue(Del{info});
}
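// swap_out() copies a tensor back to host and releases the device copy
// (regenerated later from h_value); drop() discards the device copy and
// relies on recomputation through its recorded ComputePath. Both are no-ops
// unless the corresponding option (enable_swap / enable_drop) is set.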

void ChannelImpl::swap_in(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapIn{info});
    }
}

void ChannelImpl::swap_out(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapOut{info});
    }
}

void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(Drop{info});
    }
}
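// Host-compute fast path: ops that support DispatchMode::DEFAULT_CPU are
// evaluated eagerly on the default-CPU proxy comp node, using the host values
// of their inputs, and the results are re-registered through put_impl(),
// bypassing the worker queue entirely.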

void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();

    auto name = op->trait()->make_name(*op);
    auto _ = StackManager::Guard(name, &state.stack_manager);

    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_RECORD_EVENT(ShapeInferEvent, validated);

    SmallVector<DeviceTensorND> input_tensornds;
    input_tensornds.reserve(input_descs.size());
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
            } else {
                // It's OK for SwapOut. We assign h_value before dropping ptr
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    outputs->reserve(output_descs.size());
    SmallVector<DeviceTensorND> output_tensornds;
    output_tensornds.reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which needs alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    uint64_t op_id = Profiler::next_id();

    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);

    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
            .proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(info);
    }
    auto op_info_getter = [op]{
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value]: props) {
            op_info[key] = value;
        }
        return op_info;
    };
    MGB_RECORD_EVENT(OpDispatchEvent, op_id, name, op_info_getter,
                 tinfo_to_tid(input_infos), tinfo_to_tid(output_infos),
                 state.stack_manager.dump());
}

void ChannelImpl::dispatch_kernel(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();
    auto& options = state.options;

    auto name = op->trait()->make_name(*op);
    auto _ = StackManager::Guard{name, &state.stack_manager};

    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_RECORD_EVENT(ShapeInferEvent, validated);

    ApplyOp cmd{Profiler::next_id(), std::move(op)};
    cmd.inputs = std::move(input_infos);
    cmd.outputs.reserve(output_descs.size());
    outputs->reserve(output_descs.size());
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
        auto info = alloc();
        init(info, desc);
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
                .proxy_to_comp_node(desc.comp_node);
        }
        cmd.outputs.push_back(info);
        outputs->push_back(info);
    }
    auto op_info_getter = [op=cmd.op]{
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value]: props) {
            op_info[key] = value;
        }
        return op_info;
    };
    MGB_RECORD_EVENT(OpDispatchEvent, cmd.id, name, op_info_getter,
                 tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs),
                 state.stack_manager.dump());
    m_buffer.enqueue(std::move(cmd));
    if (!validated && options.async_level == 1) {
        sync_impl();
    } else if (options.async_level == 0) {
        sync_impl();
        // check device error
        for (auto&& oup : *outputs) {
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
        }
    }
}
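// async_level semantics used above: levels above 1 stay fully asynchronous,
// level 1 synchronizes whenever shape inference could not be validated, and
// level 0 synchronizes after every op and also syncs each output comp node
// so device errors surface at the dispatch site.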

SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op,
        const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return apply_op_impl(std::move(op), inputs);
}

SmallVector<Handle> ChannelImpl::apply_op_impl(
        std::shared_ptr<OpDef> op,
        const SmallVector<Handle>& inputs) {
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                "invalid handle: %p", i);
    }
    SmallVector<TensorInfo*> input_infos;
    input_infos.reserve(inputs.size());
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.reserve(inputs.size());
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(!info->invalid, "an input tensor is unusable due to previous error");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    DispatchMode dispatch_mode = state.options.enable_host_compute
            ? OpDef::decide_dispatch_mode(*op, input_descs)
            : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

void ChannelImpl::sync() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    sync_impl();
}

void ChannelImpl::sync_impl() {
    m_buffer.flush();
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
    for (auto* handle: valid_handles) {
        del_impl(handle);
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
    sync_impl();
    m_closed = true;
}

size_t ChannelImpl::get_option(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    return state.options.get_option(name);
}

void ChannelImpl::set_option(std::string name, size_t value) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.options.set_option(name, value);
    m_buffer.enqueue(SetOption{name, value});
}

TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this]{
        MGB_LOCK_GUARD(m_mutex);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        size_t tensor_id = state.stack_manager.current()->next_id("tensor");
        info->name = state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
    }
    return info;
}

void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
    m_valid_handle.insert(info);
    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    info->desc = std::move(desc);
    info->mem_desc.layout = info->desc.layout;
    info->mem_desc.cn = info->desc.comp_node;
    info->mem_desc.offset = 0;
}

void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}
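// recursive_free: dropping the last reference to a tensor decrements the
// ref_cnt of its producer's inputs; any input that reaches zero and is
// already marked allow_delete is freed transitively, reclaiming whole dead
// recompute chains at once.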

void ChannelImpl::recursive_free(TensorInfo* ptr) {
    MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
}

void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    MGB_LOCK_GUARD(m_mutex);
    m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}

ChannelImpl::~ChannelImpl() {
    close();
}

void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
    // update tensor desc for static infer
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
}
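// regenerate(): a SWAPped tensor is rebuilt directly from its saved h_value,
// while a DROPped tensor re-runs its producer ComputePath by pushing it onto
// m_apply_stack (inputs that are themselves dropped get pushed recursively by
// flush_apply_stack).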

void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        auto &&path = dest->producer;
        m_apply_stack.push({ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest});
        if (!m_applying) flush_apply_stack();
    } else if (dest->evict_type == EvictType::SWAP) {
        MGB_RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandKind::ReGen);
        produce_tensor(dest, Tensor::make(dest->h_value));
        MGB_RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandKind::ReGen);
    }
}
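// do_apply_op runs on the worker thread: it gathers the input TensorPtrs,
// evicts under the DTR threshold if needed, executes the op (recursing into
// a forward graph when the OpDef provides one), produces the output tensors,
// and finally charges an estimated compute time to the outputs for the DTR
// heuristic.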

void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    bool profiling_device = Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
    uint64_t apply_id = cmd.id;
    struct TensorWithDesc {
        TensorPtr tensor;
        MemoryDesc desc;
    };
    SmallVector<TensorWithDesc> inputs;
    inputs.reserve(cmd.inputs.size());
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
        // refcnt ++, owners: [i->ptr, tensor_inputs]
        // tensor_inputs.push_back(i->ptr);
        inputs.push_back({i->ptr, i->mem_desc});
    }
    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
        auto_evict(0);
    }
    auto apply_on_physical_tensor = [&](auto&& self, const OpDef& def, SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
        auto apply_functor = [&](std::shared_ptr<OpDef> op, SmallVector<TensorWithDesc> inputs, size_t nr_outputs) -> SmallVector<TensorWithDesc> {
            auto opname = op->trait()->make_name(*op);
            imperative_log_profile_begin(opname.c_str());
            auto outputs = self(self, *op, inputs);
            imperative_log_profile_end(opname.c_str());
            return outputs;
        };
        auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
            return {value, MemoryDesc{value->layout(), 0, value->comp_node(), StorageIdentifier::make()}};
        };
        if (def.trait()->make_forward_graph) {
            // apply recursively
            SmallVector<LogicalTensorDesc> input_descs;
            for (auto&& input: inputs) {
                input_descs.push_back({{{}, input.tensor->dtype()}, input.tensor->comp_node()});
            }
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
            auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
            return outputs;
        }
        SmallVector<TensorPtr> input_tensors;
        SmallVector<MemoryDesc> input_descs;
        for (auto&& input: inputs) {
            input_tensors.push_back(input.tensor);
            input_descs.push_back(input.desc);
        }
        auto [output_descs, output_tensors, workspaces] = init_output_and_workspace(def, input_tensors, input_descs);
        if (!output_descs.empty()) {
            OpDef::execute(def, input_tensors, output_tensors, workspaces);
        } else {
            output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors);
            for (auto&& output_tensor: output_tensors) {
                output_descs.push_back(MemoryDesc{output_tensor->layout(), 0, output_tensor->comp_node(), StorageIdentifier::make()});
            }
        }
        SmallVector<TensorWithDesc> outputs;
        for (auto&& [output_tensor, output_desc]: ranges::zip_view(output_tensors, output_descs)) {
            outputs.push_back({output_tensor, output_desc});
        }
        return outputs;
    };
    MGB_RECORD_EVENT(OpExecuteEvent, apply_id);
    // Begin profiling operator
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
            }
        }
    }
    for (auto* input: cmd.inputs) {
        auto input_id = input->id;
        MGB_RECORD_EVENT(OpInputEvent, input_id);
        MGB_RECORD_EVENT(TensorUsageEvent, input_id);
        MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
    }
    // Fused by command buffer. @see: CommandBuffer::fuse_del
    // Now if dest is inplaceable, its refcnt would be decreased to 1 and owned by tensor_inputs after Del.
    // Note for exprs like 'y = x op x', inplace is unsupported yet but Del would also be fused.
    for (auto* del : cmd.dels) {
        // refcnt --, owners: [tensor_inputs]
        // if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor
        uint64_t del_id = del->id;
        MGB_RECORD_EVENT(TensorCommandEvent, del_id, TensorCommandKind::Del);
        free(del);
        MGB_RECORD_EVENT(TensorCommandFinishEvent, del_id, TensorCommandKind::Del);
    }
    // Before wait
    //TODO: split operator wait and execute so that OpWait could be correctly recorded.
    // Before execute
    for (auto&& [device, kernel_id]: kernels) {
        MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
        MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
    }
    // Apply op
    // Here std::move is REQUIRED for removing duplicated references.
    auto outputs = apply_on_physical_tensor(apply_on_physical_tensor, *cmd.op, inputs);
    // After execute
    for (auto&& [device, kernel_id]: kernels) {
        MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
        MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
    }
    // End profiling operator
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto output = cmd.outputs[i];
        if (output == nullptr) {
            MGB_RECORD_EVENT(OpOutputEvent, 0);
            MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
        } else if (output->ptr != nullptr) {
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
        } else {
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            produce_tensor(output, outputs[i].tensor);
            output->mem_desc = outputs[i].desc;
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
            sample_on_device(output->desc.comp_node, false);
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
        for (auto i : outputs) {
            estimate_compute_time += i.tensor->blob()->size();
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
        m_dtr.unpin(cmd.inputs);
    }
    MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id);
    // End profiling operator
}
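// flush_apply_stack drains m_apply_stack iteratively instead of recursing:
// whenever an input of the top command is missing (evicted), its regeneration
// is pushed on top and retried, so arbitrarily deep recompute chains are
// handled without growing the C++ call stack.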

void ChannelImpl::flush_apply_stack() {
    m_applying = true;
    auto& state = get_worker_state();
    while (!m_apply_stack.empty()) {
        auto& [cmd, idx, recomp] = m_apply_stack.top(); // cmd.inputs[0~idx-1] is in memory
        if (idx == 0) {
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
                MGB_RECORD_EVENT(TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
            }
        }
        bool regen = false;
        for (size_t i = idx; i < cmd.inputs.size(); i ++) {
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                idx = i + 1;
                regenerate(p); // add ApplyOp to the stack
                regen = true;
                break;
            }
        }
        if (regen) continue;
        // the required input tensors are already in memory
        auto cmd_backup = cmd;
        auto recomp_backup = recomp;
        m_apply_stack.pop();
        do_apply_op(cmd_backup);
        if (recomp_backup) {
            MGB_RECORD_EVENT(TensorCommandFinishEvent, recomp_backup->id, TensorCommandKind::ReGen);
            for (auto o : cmd_backup.outputs) {
                if (o) {
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
    }
    m_applying = false;
}
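// auto_evict(force_num): keep dropping the "best" candidate chosen by the
// DTR heuristic until device memory falls below dtr_eviction_threshold, or
// until force_num tensors whose storage is exclusively owned have been
// evicted (force_num > 0 is used when a direct allocation fails).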

bool ChannelImpl::auto_evict(size_t force_num) {
    auto& state = get_worker_state();
    if (!m_dtr.comp_node.valid()) {
        return false;
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
    size_t flag = false;
    while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) {
        MGB_RECORD_EVENT(AutoEvictEvent);
        sample_on_device(m_dtr.comp_node, false);
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling && !force_num);
        if (!best) {
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
            if (force_num > 0) {
                force_num --;
            }
            flag = true;
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
        }
        sample_on_device(m_dtr.comp_node, false);
        MGB_RECORD_EVENT(AutoEvictFinishEvent);
    }
    return flag;
}

void ChannelImpl::detach_users(TensorInfo* dest) {
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user: users) {
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output: outputs) {
            // When a `ComputePath` is detached from its input,
            // there is no need to preserve it,
            // so we detach all outputs of this path
            // to decrease its `ref_cnt` to zero.
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
            for (auto* input: inputs) {
                input->ref_cnt --;
            }
        }
        // now user is dead
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}
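// wait_tensor implements the channel-side half of the handshake: it enqueues
// a GetValue command when a host value is required, then blocks on m_cv until
// the worker calls notify_tensor_unsafe() for this TensorInfo (or records an
// exception, which is rethrown by check_worker_exc_unsafe()).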

bool ChannelImpl::check_available() {
    return !m_closed;
}

TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    m_buffer.flush();
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool require_host = prop == TensorProp::HostValue;
    auto host_available = [&]{
        return info->ptr && info->ptr->value_fetched();
    };
    if (require_host && !host_available()) {
        // avoid dead lock
        lock.unlock();
        m_buffer.enqueue(GetValue{info});
        m_buffer.flush();
        lock.lock();
    }
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return require_host ? host_available() : static_cast<bool>(info->ptr);
    });
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
    m_waitee = nullptr;
    return info->ptr;
}

void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
    }
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
    for (auto* handle: m_valid_handle) {
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
    }
    return valid_tensors;
}
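// alloc_tensor_with_evict: try a direct allocation; on MemAllocError keep
// force-evicting one tensor at a time and retrying, and as a last resort
// defragment the comp node via BlobManager before the final attempt.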

void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
    auto reserve_size = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            bool evict_suc = auto_evict(1);
            if (!evict_suc) return false;
        }
        return true;
    };
    auto pre_level = set_log_level(LogLevel::NO_LOG);
    reserve_size(x->size());
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
    MGB_CATCH(MemAllocError&, {
        bool suc = false;
        while (!suc) {
            if (!auto_evict(1)) {
                break;
            }
            MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
            MGB_CATCH(MemAllocError&, { continue; });
            suc = true;
        }
        if (!suc) {
            set_log_level(pre_level);
            mgb_log_warn("reallocating all cuda memory to alleviate fragmentation, the performance may be affected");
            set_log_level(LogLevel::NO_LOG);
            BlobManager::inst()->defrag(x->comp_node());
            BlobManager::inst()->alloc_direct(x, x->size());
        }
    });
    set_log_level(pre_level);
}
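// init_output_and_workspace: ask the op for a memory plan; system-allocated
// outputs get fresh (possibly evicting) allocations, outputs aliasing an
// input are built as sub-tensors of that input, and device-pointer outputs
// reuse the provided storage. An empty plan means the op allocates for
// itself in apply_on_physical_tensor.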

std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> ChannelImpl::init_output_and_workspace(
        const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<MemoryDesc> inputs_mem_desc) {

    auto [outputs_desc, workspaces_desc] = OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
    if (!outputs_desc.size()) {
        // failed to infer memplan
        return {{}, {}, {}};
    }
    // refine storage id to make it unique
    for (auto&& desc : outputs_desc) {
        if (desc.id->is_sys_alloc()) {
            // TODO: there may be some outputs sharing the same storage id
            desc.id->id = ++ m_storage_id;
        }
    }
    auto& state = get_worker_state();
    auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
        SmallVector<TensorPtr> tensors;
        for (size_t i = 0; i < desc.size(); i ++) {
            if (desc[i].id->is_sys_alloc()) {
                tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
                if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) {
                    alloc_tensor_with_evict(tensors.back()->blob().get());
                }
            } else if (desc[i].id->is_from_other()) {
                for (size_t j = 0; j < inputs_mem_desc.size();j ++) {
                    if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
                        tensors.push_back(inputs[j]->sub(desc[i].offset, desc[i].layout));
                        break;
                    }
                }
            } else if (desc[i].id->is_device_ptr()) {
                tensors.push_back(desc[i].id->ptr);
            } else {
                mgb_assert(0, "not implemented");
            }
        }
        return tensors;
    };

    return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
}
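// process_one_task is the worker-side dispatcher: each buffered Command is
// visited by type (Put/ApplyOp/Del/GetValue/SwapIn/SwapOut/Drop/...). When
// catch_worker_execption is enabled, any exception marks the affected outputs
// invalid, is stashed in m_worker_exc and the current waiter is woken so the
// error is rethrown on the channel thread.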

void ChannelImpl::process_one_task(Command& icmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    //TODO: remove std::visit to support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, Put>) {
                MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
                MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(cmd.value.comp_node()));
                auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
                MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(cmd.value.comp_node()));
                produce_tensor(cmd.dest, std::move(value));
                MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, ApplyOp>) {
                for (auto& i : cmd.inputs) {
                    if (i->invalid) {
                        MGB_LOCK_GUARD(m_mutex);
                        for (auto& i : cmd.outputs) {
                            i->invalid = true;
                        }
                        return;
                    }
                }
                m_apply_stack.push({cmd, 0, nullptr});
                flush_apply_stack();
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
                    if (state.options.enable_dtr_auto_drop) {
                        output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                    }
                }
                if (state.options.enable_drop && state.options.record_computing_path) {
                    auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                        auto& input = std::get<0>(tuple2);
                        auto& output = std::get<1>(tuple2);
                        if (!input->ptr || !output->ptr) {
                            return false;
                        }
                        return input->ptr->blob()->storage() == output->ptr->blob()->storage();
                    };
                    // FIXME: do not use opname as identifier
                    auto get_name = [](const OpDef& opdef) {
                        if (auto attr = opdef.try_cast_final<OprAttr>()) {
                            return attr->type.c_str();
                        }
                        return opdef.dyn_typeinfo()->name;
                    };

                    auto is_cross_cn = [comp_node=m_dtr.comp_node](TensorInfo* info){
                        return info->desc.comp_node != comp_node;
                    };

                    bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                    bool inplace = any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                    if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                        TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs);
                        size_t detach_cnt = 0;
                        if (!strcmp(get_name(*cmd.op), "BatchNorm") && cmd.outputs.size() == 5) {
                            cmd.outputs[0]->detach_producer(); // detach running_mean
                            cmd.outputs[1]->detach_producer(); // detach running_var
                            for (auto input : cmd.inputs) {
                                input->ref_cnt -= 2;
                            }
                        }
                        for (auto output : cmd.outputs) {
                            if (output->producer && !output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
                                output->detach_producer();
                                detach_cnt ++;
                            }
                        }
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= detach_cnt;
                        }
                    }
                }
            } else if constexpr (std::is_same_v<T, Del>) {
                MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
                CompNode device = cmd.dest->desc.comp_node;
                uint64_t tensor_id = cmd.dest->id;
                free(cmd.dest);
                MGB_RECORD_EVENT(TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
                sample_on_device(device, false);
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (cmd.dest->invalid) return;
                imperative_log_profile_begin("GetValue");
                if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                    regenerate(cmd.dest);
                }
                cmd.dest->ptr->fetch_value();
                MGB_LOCK_GUARD(m_mutex);
                notify_tensor_unsafe(cmd.dest);
                imperative_log_profile_end("GetValue");
            } else if constexpr (std::is_same_v<T, SwapIn>) {
                if (cmd.dest->invalid) return;
                MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::SwapIn);
                produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
                MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::SwapIn);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, SwapOut>) {
                if (cmd.dest->invalid) return;
                MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::SwapOut);
                cmd.dest->h_value = cmd.dest->ptr->get_value();
                if (cmd.dest->evict_type == EvictType::NONE) {
                    cmd.dest->evict_type = EvictType::SWAP;
                    cmd.dest->status = TensorInfo::Swapped;
                    release_tensor(cmd.dest);
                }
                MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::SwapOut);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, Drop>) {
                if (cmd.dest->invalid) return;
                MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
                do_drop(cmd.dest, true);
                MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
            } else if constexpr (std::is_same_v<T, SetOption>) {
                options.set_option(cmd.key, cmd.value);
            } else if constexpr (std::is_same_v<T, StartProfile>) {
                MGB_RECORD_EVENT(StartProfileEvent);
                CompNode::sync_all();
                for (auto* info: cmd.capture_tensors) {
                    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                    if (info->status == TensorInfo::Produced) {
                        // TODO: handle swap/drop
                        MGB_RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                    }
                }
                CompNode::foreach([&](CompNode device){
                    if (Profiler::get_option("sample_rate", 0)) {
                        sample_on_device(device, true);
                    }
                    MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
                });
                MGB_RECORD_EVENT(StartProfileFinishEvent);
            } else if constexpr (std::is_same_v<T, StopProfile>) {
                MGB_RECORD_EVENT(StopProfileEvent);
                for (auto* info: cmd.escape_tensors) {
                    bool has_value = info->status == TensorInfo::Produced;
                    if (has_value) {
                        MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
                    }
                    MGB_RECORD_EVENT(TensorEraseEvent, info->id);
                }
                CompNode::foreach([&](CompNode device){
                    if (Profiler::get_option("sample_rate", 0)) {
                        sample_on_device(device, true);
                    }
                });
                MGB_RECORD_EVENT(StopProfileFinishEvent);
            } else if constexpr (std::is_same_v<T, PushScope>) {
                MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
            } else if constexpr (std::is_same_v<T, PopScope>) {
                MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
            } else {
                static_assert(!std::is_same_v<T, T>);
            }
    };
    std::visit([&](const auto& cmd){
        using T = std::decay_t<decltype(cmd)>;
        if (!options.catch_worker_execption) {
            cmd_visitor(cmd);
            return;
        }
        try {
            cmd_visitor(cmd);
        } catch (...) {
            MGB_LOCK_GUARD(m_mutex);
            if constexpr (std::is_same_v<T, ApplyOp>) {
                for (auto oup : cmd.outputs) {
                    oup->invalid = true;
                }
            } else if constexpr (std::is_same_v<T, Put>) {
                cmd.dest->invalid = true;
            }
            m_worker_exc = std::current_exception();
            MGB_RECORD_EVENT(WorkerExceptionEvent);
            if (m_waitee) {
                notify_tensor_unsafe(m_waitee);
            }
        }
    }, icmd.data);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
        // clear waitee so interpreter_for_py can be reused after exception tests
        m_waitee = nullptr;
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
    }
}
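// CommandBuffer keeps a small window of not-yet-submitted commands so that
// patterns such as "ApplyOp immediately followed by Del of an input" can be
// fused before reaching the worker (see fuse_del below); everything past the
// window, and any command that must not be delayed, is flushed to the worker
// queue.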

void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
    auto& state = m_owner->get_channel_state();
    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
        return;
    }
    // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
    m_commands.push_back({Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
    auto flush_pos = flush_pos_for(m_commands.back());
    flush(flush_pos);
}

void ChannelImpl::CommandBuffer::flush() {
    flush(m_commands.end());
}

void ChannelImpl::CommandBuffer::flush(Handle pos) {
    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
        if (Profiler::is_profiling()) {
            mgb_log_debug("%s Flushed", to_string(*iter).c_str());
        }
        m_owner->m_worker.add_task(std::move(*iter));
    }
    m_commands.erase(m_commands.begin(), pos);
}

auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
    auto& state = m_owner->get_channel_state();
    return std::visit([this, &state](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, ApplyOp>) {
            auto* op_type = cmd.op->dyn_typeinfo();
            if (op_type == RemoteRecv::typeinfo() ||
                op_type == RemoteSend::typeinfo() ||
                op_type == CollectiveComm::typeinfo() ||
                op_type == opr::InputCallback::typeinfo() ||
                op_type == opr::OutputCallback::typeinfo()) {
                return m_commands.end();
            }
        } else if constexpr (std::is_same_v<T, GetValue>) {
            return m_commands.end();
        }
        size_t buffer_length = state.options.buffer_length;
        if (m_commands.size() > buffer_length) {
            return m_commands.begin() + (m_commands.size() - buffer_length);
        }
        return m_commands.begin();
    }, cmd.data);
}

/**
 * 1. Find ApplyOp(dest) in buffered commands
 * 2. Check if there are other usages between that ApplyOp and the Del; return false if there are
 * 3. Fuse Del into ApplyOp, return true
 */
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
    auto* dest = cmd.dest;
    // TODO: eliminate Puts
    auto begin = m_commands.begin(), end = m_commands.end();
    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
        if (auto* apply = std::get_if<ApplyOp>(&cmd.data)) {
            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
        }
        return false;
    });
    if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
        return false;
    }
    // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
    std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
    return true;
}

auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
        -> Handle {
    auto found = range[1];
    for (auto iter = range[0]; iter != range[1]; ++iter) {
        std::visit([&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
                               dest) > 0) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (cmd.dest == dest) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, SwapIn> ||
                    std::is_same_v<T, SwapOut> ||
                    std::is_same_v<T, Drop>) {
                //TODO: ignore swap-like commands, just remove them from buffer
                if (cmd.dest == dest) {
                    found = iter;
                }
            }
        }, iter->data);
    };
    return found;
}

auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
        -> Handle {
    return std::find_if(range[0], range[1], [dest](auto& cmd) {
        return std::visit([dest](const auto& cmd){
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
            } else if constexpr (std::is_same_v<T, Put>) {
                return cmd.dest == dest;
            }
            return false;
        }, cmd.data);
    });
}

void ChannelImpl::start_profile() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
        m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
    }
}

void ChannelImpl::stop_profile() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_buffer.flush();
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
        m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
    }
}

void ChannelImpl::push_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.stack_manager.enter(name);
    MGB_RECORD_EVENT(ScopeEvent, name);
    m_buffer.enqueue(PushScope{name});
}

void ChannelImpl::pop_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.stack_manager.exit(name);
    MGB_RECORD_EVENT(ScopeFinishEvent, name);
    m_buffer.enqueue(PopScope{name});
}

void ChannelImpl::assert_in_channel() {
    mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
}

void ChannelImpl::assert_in_worker() {
    mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
}

void ChannelImpl::sample_on_device(CompNode device, bool force) {
    if (!force) {
        thread_local int last_sample_id = 0;
        int sample_rate = Profiler::is_profiling() ? Profiler::get_option("sample_rate", 0) : 0;
        if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
            return;
        }
    }
    MGB_RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}
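// DynamicSublinear (DTR bookkeeping): tensors dropped together are linked in
// a disjoint-set of DsuNodes so estimate_neighbor_cost() can charge the
// accumulated recomputation time of a whole evicted neighborhood, and
// find_best_tensor() evicts the candidate with the smallest msps score
// produced by eval_func() (optionally scanning only ~sqrt(N) candidates).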

void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
    }
}

void ChannelImpl::DynamicSublinear::unpin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->unpin();
    }
}

void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    dsu_fa->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    double cost = 0;
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    return cost;
}

TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(bool enable_dtr_sqrt_sampling=false) {
    double min_msps = -1;
    TensorInfo* best = nullptr;
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
        while (sz * sz <= candidates.size()) sz ++;
    } else {
        sz = candidates.size();
    }
    for (auto i : candidates) {
        if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
            double neighbor_cost = estimate_neighbor_cost(i);
            size_t begin_ptr = reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(begin_ptr, begin_ptr + i->ptr->blob()->size());
            double free_mem = side_info.first + side_info.second;
            double msps = i->eval_func(neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
        if (--sz == 0) break;
    }
    return best;
}

void ChannelImpl::DynamicSublinear::merge(std::shared_ptr<DsuNode> &x, std::shared_ptr<DsuNode> &y) {
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(std::shared_ptr<DsuNode>& x) {
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    candidates.insert(ptr);
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    candidates.erase(ptr);
}

void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}