/**
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./interpreter_impl.h"

#include "megbrain/common.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/utils/to_string.h"

using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

#define RECORD_EVENT(type, ...) \
    if (state.profiler->is_profiling()) { \
        state.profiler->record_host<type>(type{__VA_ARGS__}); \
    } \

#define RECORD_DEVICE_EVENT(type, device, ...) \
    if (state.profiler->is_profiling()) { \
        state.profiler->record_device<type>((device), type{__VA_ARGS__}); \
    } \


std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

#define m_channel_state
#define m_worker_state
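// The two empty macros above intentionally shadow the member names: any direct
// use of m_channel_state / m_worker_state below this point expands to nothing
// and fails to compile, so all access has to go through get_channel_state() /
// get_worker_state(), which assert that the caller is on the expected thread.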

std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

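// put(HostTensorND): allocate a TensorInfo, remember the host value and enqueue
// a Put command for the worker. With async_level == 0 the call becomes fully
// synchronous: it waits for the worker and then for the target comp node.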
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    mgb_assert(check_available(), "Channel already closed");
    auto info = alloc();
    info->desc.layout = value.layout();
    info->desc.comp_node = value.comp_node();
    info->desc.value = value.proxy_to_default_cpu();
    info->h_value = value;
    m_buffer.enqueue(Put{info, value, no_cache});
    if (m_async_level == 0) {
        sync();
        info->desc.comp_node.sync();
    }
    return info;
}

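// put(DeviceTensorND): the data already lives on the device, so the tensor can
// be produced immediately without going through the worker queue.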
Handle ChannelImpl::put(const DeviceTensorND& data) {
    auto& state = get_channel_state();
    mgb_assert(check_available(), "Channel already closed");
    auto info = alloc();
    info->desc.layout = data.layout();
    info->desc.comp_node = data.comp_node();
    info->ptr = Tensor::make(data);
    RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node);
    return info;
}

void ChannelImpl::del(Handle handle) {
    if (!check_available()){
        return;
    }
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    m_buffer.enqueue(Del{info});
}

void ChannelImpl::swap_in(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapIn{info});
    }
}

void ChannelImpl::swap_out(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapOut{info});
    }
}

void ChannelImpl::drop(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(Drop{info});
    }
}

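// Host-side (DEFAULT_CPU) dispatch: run the op eagerly on host values proxied
// to the default CPU comp node, bypassing the worker thread. Every input must
// either have a device value that is already fetched or a cached host value.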
void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();
    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_MARK_USED_VAR(validated);

    SmallVector<DeviceTensorND> input_tensornds;
    input_tensornds.reserve(input_descs.size());
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
            } else {
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    outputs->reserve(output_descs.size());
    SmallVector<DeviceTensorND> output_tensornds;
    output_tensornds.reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which needs alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
        SmallVector<uint64_t> tid;
        for (auto* ptinfo: tinfo) {
            tid.push_back(ptinfo->id);
        }
        return tid;
    };
    auto apply_id = ++m_last_id;
    RECORD_EVENT(OpExecuteEvent, apply_id, op, tinfo_to_tid(input_infos), {});

    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);

    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
            .proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(info);
    }

    RECORD_EVENT(OpExecuteFinishEvent, apply_id, op, 
            tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));
}

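// Device (KERNEL) dispatch: allocate output TensorInfos from the inferred
// descriptors and enqueue an ApplyOp command; the kernel itself runs later on
// the worker thread. async_level 1/0 turn this into a (partially) blocking call.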
void ChannelImpl::dispatch_kernel(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();
    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);

    ApplyOp cmd{std::move(op)};
    cmd.inputs = std::move(input_infos);
    cmd.outputs.reserve(output_descs.size());
    outputs->reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        auto info = alloc();
        info->desc = desc;
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
                .proxy_to_comp_node(desc.comp_node);
        }
        cmd.outputs.push_back(info);
        outputs->push_back(info);
    }
    m_buffer.enqueue(std::move(cmd));
    if (!validated && state.options.async_level == 1) {
        sync();
    } else if (state.options.async_level == 0) {
        sync();
        // check device error
        for (auto&& oup : *outputs) {
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
        }
    }
}

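// apply_op is the main execution entry point: it snapshots the input
// descriptors under the lock and then routes to the host or kernel path,
// depending on OpDef::decide_dispatch_mode and the enable_host_compute option.
// A typical caller sequence looks roughly like this (illustrative sketch only,
// not taken from real client code):
//     auto channel = Interpreter::inst().create_channel();
//     auto h = channel->put(host_tensor, /*no_cache=*/false);  // host_tensor: some HostTensorND
//     auto outs = channel->apply_op(op, {h});
//     HostTensorND result = channel->get_value(outs[0]);
//     channel->del(outs[0]);
//     channel->del(h);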
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op,
        const SmallVector<Handle>& inputs) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                "invalid handle: %p", i);
    }
    SmallVector<TensorInfo*> input_infos;
    input_infos.reserve(inputs.size());
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.reserve(inputs.size());
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    DispatchMode dispatch_mode = state.options.enable_host_compute
            ? OpDef::decide_dispatch_mode(*op, input_descs)
            : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

HostTensorND ChannelImpl::get_value(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    // TODO: maybe get_value should be done on host. i.e. delete GetValue
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    mgb_assert(!m_waitee);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    TensorPtr tensor_ptr = info->ptr;
    auto value_fetched = [&]() {
        return tensor_ptr && tensor_ptr->value_fetched();
    };
    if (!value_fetched()) {
        m_waitee = info;
        m_buffer.enqueue(GetValue{info});
        RECORD_EVENT(TensorWaitPropEvent, info->id, TensorInfo::HostValue);
        m_cv.wait(lock, [&]() {
            check_worker_exc_unsafe();
            tensor_ptr = info->ptr;
            return value_fetched();
        });
        RECORD_EVENT(TensorWaitPropFinishEvent, info->id, TensorInfo::HostValue);
        m_waitee = nullptr;
    }
    return tensor_ptr->get_value();
}

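// get_value blocks the calling (channel) thread: it enqueues a GetValue command
// and waits on m_cv until the worker has fetched the host value, rethrowing any
// worker exception via check_worker_exc_unsafe.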
TensorShape ChannelImpl::get_shape(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee);
    m_waitee = info;
    m_buffer.flush();
    RECORD_EVENT(TensorWaitPropEvent, info->id, TensorInfo::Shape);
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return static_cast<bool>(info->ptr);
    });
    RECORD_EVENT(TensorWaitPropFinishEvent, info->id, TensorInfo::Shape);
    m_waitee = nullptr;
    TensorShape ret = info->ptr->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

DType ChannelImpl::get_dtype(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    RECORD_EVENT(TensorGetPropEvent, info->id, TensorInfo::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

CompNode ChannelImpl::get_device(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    RECORD_EVENT(TensorGetPropEvent, info->id, TensorInfo::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee);
    m_waitee = info;
    m_buffer.flush();
    RECORD_EVENT(TensorWaitPropEvent, info->id, TensorInfo::DevValue);
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return static_cast<bool>(info->ptr);
    });
    RECORD_EVENT(TensorWaitPropFinishEvent, info->id, TensorInfo::DevValue);
    m_waitee = nullptr;
    return info->ptr->dev_tensor();
}

void ChannelImpl::sync() {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    m_buffer.flush();
    RECORD_EVENT(SyncEvent);
    m_worker.wait_all_task_finish();
    CompNode::sync_all();
    RECORD_EVENT(SyncFinishEvent);
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
    for (auto* handle: valid_handles) {
        del(handle);
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensors exist before channel close", (long)valid_handles.size());
    sync();
    m_closed = true;
}

size_t ChannelImpl::get_option(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    return state.options.get_option(name);
}

void ChannelImpl::set_option(std::string name, size_t value) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.options.set_option(name, value);
    m_buffer.enqueue(SetOption{name, value});
}

TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    MGB_LOCK_GUARD(m_mutex);
    auto info = m_pool.alloc();
    m_valid_handle.insert(info);
    info->id = m_last_id++;
    RECORD_EVENT(TensorDeclareEvent, info->id);
    return info;
}


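// do_drop releases the device storage of a tensor that can later be regenerated
// by recomputing its producer. Tensors without a recorded producer cannot be
// dropped; a user-initiated drop on such a tensor only logs a warning.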
void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;
    }
    ptr->evict_type = EvictType::DROP;
    release_tensor(ptr);
}

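// free: with DTR auto-drop enabled a tensor is released lazily. It is marked
// deletable and either freed recursively once its ref_cnt reaches zero, or only
// dropped (evicted) so that its metadata can still serve recomputation;
// otherwise it is freed immediately.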
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

void ChannelImpl::recursive_free(TensorInfo* ptr) {
    SmallVector<TensorInfo*> inps(0);
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
}

void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    RECORD_EVENT(TensorEraseEvent, ptr->id);
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}

ChannelImpl::~ChannelImpl() {
    close();
}

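// produce_tensor installs a computed or reloaded TensorPtr into its TensorInfo
// and refreshes the metadata used for static infer and DTR bookkeeping. With
// notice=false (used during recomputation) locking, profiling and waiter
// notification are skipped.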
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
    auto& state = get_worker_state();
    auto lock = std::unique_lock<std::mutex>(m_mutex, std::defer_lock);
    if (notice) {
        lock.lock();
    }
    m_dtr.update_used_time(dest);
    if (notice) {
        RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node());
    }
    dest->value_fetched = ptr->value_fetched();
    // update tensor desc for static infer
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    if (notice && m_waitee == dest) {
        m_cv.notify_all();
    }
}

void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
}

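// regenerate restores an evicted tensor: DROP-evicted tensors are recomputed
// from their recorded ComputePath, SWAP-evicted tensors are reloaded from the
// saved host value.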
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        recompute(dest->producer);
    } else if (dest->evict_type == EvictType::SWAP) {
        produce_tensor(dest, Tensor::make(dest->h_value));
    }
}

void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
    auto& state = get_worker_state();
    SmallVector<TensorPtr> inputs;
    inputs.reserve(path->inputs.size());
    m_dtr.pin(path->inputs);
    for (auto i : path->inputs) {
        if (!i->ptr) {
            regenerate(i);
        }
        inputs.push_back(i->ptr);
        m_dtr.update_used_time(i);
    }
    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
        auto_evict();
    }
    auto outputs = OpDef::apply_on_physical_tensor(*path->op, inputs);
    m_dtr.estimate_timestamp += path->compute_time / 1e8;
    m_dtr.unpin(path->inputs);
    for (size_t i = 0; i < outputs.size(); i++) {
        auto&& o = path->outputs[i];
        if (o) {
            o->recompute_times ++;
            if (!o->ptr) {
                produce_tensor(o, std::move(outputs[i]), false);
                if (state.options.enable_dtr_auto_drop) {
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
    }
}

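// auto_evict repeatedly drops the tensor with the lowest heuristic eviction
// score (see DynamicSublinear::find_best_tensor) until the device memory usage
// falls below the configured dtr_eviction_threshold or no candidate is left.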
void ChannelImpl::auto_evict() {
    auto& state = get_worker_state();
    if (!m_dtr.comp_node.valid()) {
        return;
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
    while (current_memory > state.options.dtr_eviction_threshold) {
        auto best = m_dtr.find_best_tensor();
        if (!best) {
            if (!m_dtr.warn_printed) {
                m_dtr.warn_printed = true;
                mgb_log_warn("No tensors on %s can be evicted automatically "
                             "when memory usage is %.0lfMB. Maybe memory "
                             "budget is too small.",
                              m_dtr.comp_node.to_string().c_str(),
                              current_memory / 1024.0 / 1024.0);
            }
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
        }
    }
}

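// detach_users breaks every ComputePath that consumes this tensor: each of its
// outputs is regenerated first, then detached from its producer, and the
// producer's input ref counts are decremented accordingly.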
void ChannelImpl::detach_users(TensorInfo* dest) {
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user: users) {
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output: outputs) {
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
            for (auto* input: inputs) {
                input->ref_cnt --;
            }
        }
    }
    mgb_assert(dest->users.size() == 0);
    //dest->users.clear();
}

bool ChannelImpl::check_available() {
    return !m_closed;
}

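// sync_device_scope replays the channel's profiling scope stack onto a device:
// scopes that are no longer open are popped and newly opened ones are pushed,
// so device-side events get attributed to the correct scope.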
void ChannelImpl::sync_device_scope(CompNode device) {
    auto& state = get_worker_state();
    auto& prev = state.device_scope_map[device];
    auto& current = state.scopes;
    auto push_scope = [&](std::string name) {
        RECORD_DEVICE_EVENT(DeviceScopeEvent, device, name);
    };
    auto pop_scope = [&](std::string name) {
        RECORD_DEVICE_EVENT(DeviceScopeFinishEvent, device, name);
    };
    size_t similarity = 0;
    for (size_t i = 0; i < prev.size() && i < current.size(); i++) {
        if (prev[i] == current[i]) {
            similarity++;
        } else {
            break;
        }
    }
    while (prev.size() > similarity) {
        pop_scope(prev.back());
        prev.pop_back();
    }
    while (prev.size() < current.size()) {
        prev.push_back(current[prev.size()]);
        push_scope(prev.back());
    }
}

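// process_one_task is the worker-thread command interpreter: it dispatches on
// the command type (Put, ApplyOp, Del, GetValue, SwapIn/Out, Drop, SetOption,
// Start/StopProfile, Push/PopScope) and, when catch_worker_execption is set,
// converts exceptions into m_worker_exc so the waiting channel thread can
// rethrow them.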
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
    auto& state = get_worker_state();
    RECORD_EVENT(CommandExecuteEvent, icmd);
    bool finished = false;
    auto do_finish_command = [&]{
        if (finished) {
            return;
        }
        RECORD_EVENT(CommandFinishEvent, icmd);
        finished = true;
    };
    //TODO: remove std::visit to support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, Put>) {
                auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
                produce_tensor(cmd.dest, std::move(value));
            } else if constexpr (std::is_same_v<T, ApplyOp>) {
                uint64_t apply_id = ++m_last_id;
                SmallVector<TensorPtr> tensor_inputs;
                SmallVector<CompNode> devices;
                if (state.options.enable_dtr_auto_drop) {
                    m_dtr.pin(cmd.inputs);
                }
                for (auto i : cmd.inputs) {
                    if (!i->ptr && i->evict_type != EvictType::NONE) {
                        regenerate(i);
                    }
                    m_dtr.update_used_time(i);
                }
                tensor_inputs.reserve(cmd.inputs.size());
                // refcnt == 1, owners: [TensorInfo::ptr]
                for (auto i : cmd.inputs) {
                    mgb_assert(i->ptr, "Invalid input tensor ptr!");
                    // refcnt ++, owners: [i->ptr, tensor_inputs]
                    tensor_inputs.push_back(i->ptr);
                }
                // Begin profiling operator
                auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
                    SmallVector<uint64_t> tid;
                    for (auto* ptinfo: tinfo) {
                        tid.push_back(ptinfo->id);
                    }
                    return tid;
                };
                if (state.profiler->is_profiling()) {
                    // Collecting devices
                    for (auto i : cmd.inputs) {
                        devices.push_back(i->desc.comp_node);
                    }
                    for (auto i : cmd.outputs) {
                        devices.push_back(i->desc.comp_node);
                    }
                    devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
                }
                // Fused by command buffer. @see: CommandBuffer::fuse_del
                // Now if dest is inplacable, its refcnt would be decreased to 1 and owned by tensor_inputs after Del.
                // Note for exprs like 'y = x op x', inplace is unsupported yet, but Del would also be fused.
                for (auto* del : cmd.dels) {
                    // refcnt --, owners: [tensor_inputs]
                    // if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor
                    free(del);
                }
                // Before wait
                //TODO: split operator wait and execute so that OpWait could be correctly recorded.
                // Before execute
                RECORD_EVENT(OpExecuteEvent, apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
                if (state.profiler->is_profiling()) {
                    for (auto&& device: devices) {
                        sync_device_scope(device);
                        RECORD_DEVICE_EVENT(KernelExecuteEvent, device, apply_id, cmd.op,
                                tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
                    }
                }
                if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
                    auto_evict();
                }
                // Apply op
                // Here std::move is REQUIRED for removing duplicated references.
                auto tensor_outputs = OpDef::apply_on_physical_tensor(
                    *cmd.op, std::move(tensor_inputs));
                // After execute
                RECORD_EVENT(OpExecuteFinishEvent, apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
                if (state.profiler->is_profiling()) {
                    for (auto&& device: devices) {
                        RECORD_DEVICE_EVENT(KernelExecuteFinishEvent, device, apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
                    }
                }
                // End profiling operator
                double estimate_compute_time = 0;
                if (state.options.enable_dtr_auto_drop) {
                    for (auto i : cmd.inputs) {
                        estimate_compute_time += i->memory;
                    }
                    for (auto i : tensor_outputs) {
                        estimate_compute_time += i->blob()->size();
                    }
                    m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
                    for (auto i : cmd.outputs) {
                        i->compute_time = estimate_compute_time;
                        m_dtr.update_used_time(i);
                    }
                    if (cmd.outputs[0]->producer) {
                        cmd.outputs[0]->producer->compute_time = estimate_compute_time;
                    }
                    m_dtr.unpin(cmd.inputs);
                }
                mgb_assert(tensor_outputs.size() == cmd.outputs.size());
                for (size_t i = 0; i < tensor_outputs.size(); ++i) {
                    if (cmd.outputs[i] == nullptr) {
                        continue;
                    }
                    produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
                    if (state.options.enable_dtr_auto_drop) {
                        cmd.outputs[i]->dsu_ptr = std::make_shared<DsuNode>(estimate_compute_time);
                    }
                }
                if (state.options.enable_drop == 1
                    && state.options.record_computing_path == 1){
                    bool is_inplace = false;
                    bool cross_cn = false;
                    for (auto input : cmd.inputs) {
                        for (auto output : cmd.outputs) {
                            if (input->ptr->blob()->storage() == output->ptr->blob()->storage()) {
                                is_inplace = true;
                                break;
                            }
                        }
                    }
                    for (auto input : cmd.inputs) {
                        if (input->ptr->comp_node() != m_dtr.comp_node) {
                            cross_cn = true;
                            break;
                        }
                    }
                    for (auto output : cmd.outputs) {
                        if (output->ptr->comp_node() != m_dtr.comp_node) {
                            cross_cn = true;
                            break;
                        }
                    }
                    // FIXME: do not use opname as identifier
                    auto get_name = [](const OpDef& opdef) {
                        if (auto attr = opdef.try_cast_final<OprAttr>()) {
                            return attr->type.c_str();
                        }
                        return opdef.dyn_typeinfo()->name;
                    };
                    if (!is_inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                        TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
                        size_t detach_cnt = 0;
                        for (auto output : cmd.outputs) {
                            if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
                                output->detach_producer();
                                detach_cnt ++;
                            }
                        }
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= detach_cnt;
                        }
                    }
                }
            } else if constexpr (std::is_same_v<T, Del>) {
                free(cmd.dest);
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                    regenerate(cmd.dest);
                }
                mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
                cmd.dest->ptr->fetch_value();
                MGB_LOCK_GUARD(m_mutex);
                cmd.dest->value_fetched = true;
                if (m_waitee == cmd.dest) {
                    m_cv.notify_all();
                }
            } else if constexpr (std::is_same_v<T, SwapIn>) {
                produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
            } else if constexpr (std::is_same_v<T, SwapOut>) {
                cmd.dest->h_value = cmd.dest->ptr->get_value();
                if (cmd.dest->evict_type == EvictType::NONE) {
                    release_tensor(cmd.dest);
                    cmd.dest->evict_type = EvictType::SWAP;
                }
            } else if constexpr (std::is_same_v<T, Drop>) {
                do_drop(cmd.dest, true);
            } else if constexpr (std::is_same_v<T, SetOption>) {
                state.options.set_option(cmd.key, cmd.value);
            } else if constexpr (std::is_same_v<T, StartProfile>) {
                CompNode::sync_all();
                state.profiler.reset(cmd.profiler);
            } else if constexpr (std::is_same_v<T, StopProfile>) {
                for (auto&& [device, scopes]: state.device_scope_map) {
                    MGB_MARK_USED_VAR(scopes);
                    sync_device_scope(device);
                }
                do_finish_command();
                auto profiler = std::make_unique<InterpreterProfiler>();
                std::swap(profiler, state.profiler);
                auto records = profiler->stop();
                auto worker_tid = get_worker_tid();
                auto host_map = [worker_tid](std::thread::id tid) {
                    if (tid == worker_tid) {
                        return "worker";
                    } else {
                        return "unknown";
                    }
                };
                InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map);
            } else if constexpr (std::is_same_v<T, PushScope>) {
                state.scopes.push_back(cmd.scope_name);
                do_finish_command();
                RECORD_EVENT(ScopeEvent, cmd.scope_name);
            } else if constexpr (std::is_same_v<T, PopScope>) {
                mgb_assert(state.scopes.back() == cmd.scope_name, "scope name mismatch");
                state.scopes.pop_back();
                do_finish_command();
                RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
            } else {
                static_assert(!std::is_same_v<T, T>);
            }
    };
    std::visit([&](const auto& cmd){
        using T = std::decay_t<decltype(cmd)>;
        if (!state.options.catch_worker_execption) {
            cmd_visitor(cmd);
            return;
        }
        try {
            cmd_visitor(cmd);
        } catch (...) {
            MGB_LOCK_GUARD(m_mutex);
            if constexpr (std::is_same_v<T, ApplyOp>) {
                for (auto oup : cmd.outputs) {
                    oup->invalid = true;
                }
            } else if constexpr (std::is_same_v<T, Put>) {
                cmd.dest->invalid = true;
            }
            m_worker_exc = std::current_exception();
            m_cv.notify_all();
        }
    }, icmd.second);
    do_finish_command();
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
        std::rethrow_exception(exc);
    }
}

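// The command buffer delays commands on the channel thread so that short-lived
// tensors can be optimized away before reaching the worker; in particular a Del
// may be fused into the ApplyOp that last uses the tensor (see fuse_del below).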
void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
        return;
    }
    // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
    m_commands.push_back(std::move(cmd));
    auto flush_pos = flush_pos_for(m_commands.back());
    flush(flush_pos);
}

void ChannelImpl::CommandBuffer::flush() {
    flush(m_commands.end());
}

void ChannelImpl::CommandBuffer::flush(Handle pos) {
    auto& state = m_owner->get_channel_state();
    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
        // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
        IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
        RECORD_EVENT(CommandEnqueueEvent, icmd);
        m_owner->m_worker.add_task(std::move(icmd));
    }
    m_commands.erase(m_commands.begin(), pos);
}

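// flush_pos_for decides how much of the buffer a newly enqueued command forces
// out: communication/callback ops and GetValue flush everything, while ordinary
// commands only keep the buffer within options.buffer_length entries.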
auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
    auto& state = m_owner->get_channel_state();
    return std::visit([&, this](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, ApplyOp>) {
            auto* op_type = cmd.op->dyn_typeinfo();
            if (op_type == RemoteRecv::typeinfo() ||
                op_type == RemoteSend::typeinfo() ||
                op_type == CollectiveComm::typeinfo() ||
                op_type == opr::InputCallback::typeinfo() ||
                op_type == opr::OutputCallback::typeinfo()) {
                return m_commands.end();
            }
        } else if constexpr (std::is_same_v<T, GetValue>) {
            return m_commands.end();
        }
        size_t buffer_length = state.options.buffer_length;
        if (m_commands.size() > buffer_length) {
            return m_commands.begin() + (m_commands.size() - buffer_length);
        }
        return m_commands.begin();
    }, cmd);
}

/**
 * 1. Find ApplyOp(dest) in buffered commands
 * 2. Check if there are other usages between ApplyOp and Del, return false if not
 * 3. Fuse Del into ApplyOp, return true
 */
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
    auto* dest = cmd.dest;
    // TODO: eliminate Puts
    auto begin = m_commands.begin(), end = m_commands.end();
    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
        if (auto* apply = std::get_if<ApplyOp>(&cmd)) {
            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
        }
        return false;
    });
    if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
        return false;
    }
    // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
    std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
    return true;
}

auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
        -> Handle {
    auto found = range[1];
    for (auto iter = range[0]; iter != range[1]; ++iter) {
        std::visit([&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
                               dest) > 0) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (cmd.dest == dest) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, SwapIn> ||
                    std::is_same_v<T, SwapOut> ||
                    std::is_same_v<T, Drop>) {
                //TODO: ignore swap-like commands, just remove them from buffer
                if (cmd.dest == dest) {
                    found = iter;
                }
            }
        }, *iter);
    };
    return found;
}

auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
        -> Handle {
    return std::find_if(range[0], range[1], [dest](auto& cmd) {
        return std::visit([dest](const auto& cmd){
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
            } else if constexpr (std::is_same_v<T, Put>) {
                return cmd.dest == dest;
            }
            return false;
        }, cmd);
    });
}

void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    auto profiler_option = InterpreterProfiler::Option::from_dict(option);
    auto profiler = std::make_unique<InterpreterProfiler>();
    profiler->set_option(profiler_option);
    profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic));
    std::swap(profiler, state.profiler);
    m_buffer.enqueue(StartProfile{state.profiler.get()});
}

void ChannelImpl::stop_profile(std::string basename, std::string format) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    m_buffer.flush();
    auto profiler = std::make_unique<InterpreterProfiler>();
    std::swap(profiler, state.profiler);
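    // The profiler swapped out here appears to be co-owned by the worker state
    // (its raw pointer was handed over via the StartProfile command), so the
    // channel-side unique_ptr is release()d instead of destroyed.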
    profiler.release();
    m_buffer.enqueue(StopProfile{basename, format});
}

void ChannelImpl::push_scope(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    RECORD_EVENT(ScopeEvent, name);
    if (state.profiler->is_profiling()) {
        state.scopes.push_back(name);
        m_buffer.enqueue(PushScope{name});
    }
}

void ChannelImpl::pop_scope(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    RECORD_EVENT(ScopeFinishEvent, name);
    if (state.profiler->is_profiling()) {
        mgb_assert((!state.scopes.empty()) && state.scopes.back() == name, "scope name mismatch");
        state.scopes.pop_back();
        m_buffer.enqueue(PopScope{name});
    }
}

void ChannelImpl::assert_in_channel() {
    mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
}

void ChannelImpl::assert_in_worker() {
    mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
}

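// DynamicSublinear implements the DTR (dynamic tensor rematerialization)
// bookkeeping: pin/unpin protect tensors currently in use, the DSU over dropped
// tensors accumulates the recompute time of whole evicted neighborhoods, and
// find_best_tensor picks the candidate with the smallest heuristic score from
// TensorInfo::eval_func (recompute cost, free memory around the blob and
// staleness all enter that estimate).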
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
    }
}

void ChannelImpl::DynamicSublinear::unpin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->unpin();
    }
}

void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    dsu_fa->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    double cost = 0;
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    return cost;
}

TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() {
    double min_msps = -1;
    TensorInfo* best = nullptr;
    for (auto i : candidates) {
        if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
            double neighbor_cost = estimate_neighbor_cost(i);
            size_t begin_ptr = reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(begin_ptr, begin_ptr + i->ptr->blob()->size());
            double free_mem = side_info.first + side_info.second;
            double msps = i->eval_func(neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
    }
    return best;
}

void ChannelImpl::DynamicSublinear::merge(std::shared_ptr<DsuNode> &x, std::shared_ptr<DsuNode> &y) {
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(std::shared_ptr<DsuNode>& x) {
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    candidates.insert(ptr);
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    candidates.erase(ptr);
}

void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}