/**
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./interpreter_impl.h"

#include "range/v3/all.hpp"

#include "megbrain/common.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/utils/to_string.h"

#include "../event_pool.h"
#include "../op_trait.h"

using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

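// RECORD_EVENT forwards an event to the profiler only when profiling is
// enabled, so the instrumentation scattered through this file costs no more
// than a predicate check on the non-profiling path.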
#define RECORD_EVENT(type, ...) \
    if (Profiler::is_profiling()) { \
        Profiler::record<type>(type{__VA_ARGS__}); \
    } \


namespace {
    auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
        SmallVector<uint64_t> tid;
        for (auto* ptinfo: tinfo) {
            tid.push_back(ptinfo->id);
        }
        return tid;
    };
}

namespace mgb {
    using namespace profiler;
}

#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif

namespace mgb {

/**
 * USAGE
 *
 *   header:
 *     namespace mgb { void imperative_log_profile(const char* message); }
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    RECORD_EVENT(CustomEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    RECORD_EVENT(CustomFinishEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile(const char* message){
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}
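// Note: the begin/end pair above can also bracket a longer region, e.g.
//   mgb::imperative_log_profile_begin("my-region");
//   /* ... workload ... */
//   mgb::imperative_log_profile_end("my-region");
// where "my-region" is an arbitrary user-chosen label.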

}

std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

// Do not use m_xxx_state directly
#define m_channel_state
#define m_worker_state
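// The empty defines above make any direct use of m_channel_state or
// m_worker_state fail to compile, so all accesses have to go through the
// asserting getters, which check that the caller runs on the right thread.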

std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.push("Put");
    auto info = put_impl(value, no_cache);
    state.scopes.pop("Put");
    return info;
}

TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    auto info = alloc();
    init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
    info->h_value = value;
    m_buffer.enqueue(Put{info, value, no_cache});
    if (m_async_level == 0) {
        sync();
        info->desc.comp_node.sync();
    }
    return info;
}

Handle ChannelImpl::put(const DeviceTensorND& data) {
    auto& state = get_channel_state();
    mgb_assert(check_available(), "Channel already closed");
    state.scopes.push("Put");
    auto info = alloc();
    RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
    init(info, {data.layout(), data.comp_node()});
    info->ptr = Tensor::make(data);
    RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
    info->status = TensorInfo::Produced;
    RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put);
    state.scopes.pop("Put");
    return info;
}

void ChannelImpl::del(Handle handle) {
    if (!check_available()){
        return;
    }
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    m_buffer.enqueue(Del{info});
}

void ChannelImpl::swap_in(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapIn{info});
    }
}

void ChannelImpl::swap_out(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapOut{info});
    }
}

void ChannelImpl::drop(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(Drop{info});
    }
}

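// Fast path for ops that can be computed eagerly on host memory: the inputs
// are proxied to the default CPU comp node, the op is applied synchronously
// via apply_on_device_tensornd, and each result is re-inserted with put_impl,
// so only Put commands (not the op itself) reach the worker.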
void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();
    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    RECORD_EVENT(ShapeInferEvent, validated);

    SmallVector<DeviceTensorND> input_tensornds;
    input_tensornds.reserve(input_descs.size());
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
            } else {
                // It's OK for SwapOut: h_value is assigned before ptr is dropped.
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    outputs->reserve(output_descs.size());
    SmallVector<DeviceTensorND> output_tensornds;
    output_tensornds.reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which needs alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    uint64_t op_id = Profiler::next_id();

    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);

    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
            .proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(info);
    }
    auto op_info_getter = [op]{
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value]: props) {
            op_info[key] = value;
        }
        return op_info;
    };
    RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));
}

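// Normal dispatch path: output TensorInfos are allocated from the inferred
// descriptors and an ApplyOp command is queued for the worker thread. The
// call only blocks when async_level forces a sync (level 0, or level 1 with
// an unvalidated shape inference).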
void ChannelImpl::dispatch_kernel(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();
    auto& options = state.options;

    auto name = op->trait()->make_name(*op);
    state.scopes.push(name);

    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    RECORD_EVENT(ShapeInferEvent, validated);

    ApplyOp cmd{Profiler::next_id(), std::move(op)};
    cmd.inputs = std::move(input_infos);
    cmd.outputs.reserve(output_descs.size());
    outputs->reserve(output_descs.size());
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
        auto info = alloc();
        init(info, desc);
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
                .proxy_to_comp_node(desc.comp_node);
        }
        cmd.outputs.push_back(info);
        outputs->push_back(info);
    }
    auto op_info_getter = [op=cmd.op]{
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value]: props) {
            op_info[key] = value;
        }
        return op_info;
    };
    RECORD_EVENT(OpDispatchEvent, cmd.id, cmd.op->trait()->name, op_info_getter, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
    m_buffer.enqueue(std::move(cmd));
    if (!validated && options.async_level == 1) {
        sync();
    } else if (options.async_level == 0) {
        sync();
        // check device error
        for (auto&& oup : *outputs) {
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
        }
    }
    state.scopes.pop(name);
}

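// Entry point for executing an op on existing handles. A rough usage sketch
// (variable names are illustrative only):
//
//   auto channel = Interpreter::inst().create_channel();
//   auto h = channel->put(host_value, /* no_cache */ false);
//   auto outs = channel->apply_op(op, {h});
//   HostTensorND result = channel->get_value(outs[0]);
//   channel->del(h);
//
// Depending on enable_host_compute and the op itself, the call is routed to
// dispatch_default_cpu or dispatch_kernel above.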
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op,
        const SmallVector<Handle>& inputs) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                "invalid handle: %p", i);
    }
    SmallVector<TensorInfo*> input_infos;
    input_infos.reserve(inputs.size());
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.reserve(inputs.size());
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    DispatchMode dispatch_mode = state.options.enable_host_compute
            ? OpDef::decide_dispatch_mode(*op, input_descs)
            : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

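// Blocks the calling thread until the worker has fetched the host value of
// this tensor (see wait_tensor), then returns that value.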
HostTensorND ChannelImpl::get_value(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

TensorShape ChannelImpl::get_shape(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

DType ChannelImpl::get_dtype(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

CompNode ChannelImpl::get_device(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

void ChannelImpl::sync() {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    m_buffer.flush();
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
    for (auto* handle: valid_handles) {
        del(handle);
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
    sync();
    m_closed = true;
}

size_t ChannelImpl::get_option(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    return state.options.get_option(name);
}

void ChannelImpl::set_option(std::string name, size_t value) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.options.set_option(name, value);
    m_buffer.enqueue(SetOption{name, value});
}

TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this]{
        MGB_LOCK_GUARD(m_mutex);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        info->name = state.scopes.next_tensor_name();
    }
    return info;
}

void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
    m_valid_handle.insert(info);
    RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    info->desc = std::move(desc);
}


void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

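// Releases a TensorInfo once the user no longer references it. Under DTR
// (enable_dtr_auto_drop) the info may still be needed to recompute other
// dropped tensors, so it is only freed for real once its ref_cnt reaches
// zero; otherwise it is merely dropped and kept as a recompute record.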
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it immediately, can avoid
        // pinning a potentially exploding amount of memory and allows us to
        // save more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

void ChannelImpl::recursive_free(TensorInfo* ptr) {
    RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandEvent::RecFree);
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandFinishEvent::RecFree);
}

void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}

ChannelImpl::~ChannelImpl() {
    close();
}

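// Installs a freshly computed TensorPtr into `dest`, updates its descriptor,
// memory accounting and the DTR candidate set, and wakes up any thread
// blocked in wait_tensor. When `notice` is false, m_mutex is not taken and
// no waiter is notified.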
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
    auto& state = get_worker_state();
    std::unique_lock<std::mutex> lock{m_mutex, std::defer_lock};
    if (notice) {
        lock.lock();
    }
    m_dtr.update_used_time(dest);
    RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
    // update tensor desc for static infer
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    if (notice) {
        notify_tensor_unsafe(dest);
    }
}

void ChannelImpl::release_tensor(TensorInfo* dest) {
    RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
}

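// Restores an evicted tensor: a DROPped tensor is recomputed through its
// producer ComputePath, while a SWAPped-out tensor is rebuilt from the host
// copy kept in h_value.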
void ChannelImpl::regenerate(TensorInfo* dest) {
    RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
    if (dest->evict_type == EvictType::DROP) {
        recompute(dest->producer);
    } else if (dest->evict_type == EvictType::SWAP) {
        produce_tensor(dest, Tensor::make(dest->h_value));
    }
    RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
}

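// Worker-side execution of an ApplyOp command: dropped inputs are regenerated
// and pinned (under DTR), Dels fused into the command are released, automatic
// eviction may run if a threshold is configured, and finally the kernel is
// applied and each output is handed to produce_tensor. Per-device kernel
// events are only recorded when the profile_device option is set.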
void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    bool profiling_device = Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
    uint64_t apply_id = cmd.id;
    SmallVector<TensorPtr> tensor_inputs;
    if (state.options.enable_dtr_auto_drop) {
        m_dtr.pin(cmd.inputs); 
    }
    for (auto i : cmd.inputs) {
        if (!i->ptr && i->evict_type != EvictType::NONE) {
            regenerate(i);
        }
        m_dtr.update_used_time(i);
    }
    tensor_inputs.reserve(cmd.inputs.size());
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
        // refcnt ++, owners: [i->ptr, tensor_inputs]
        tensor_inputs.push_back(i->ptr);
    }
    RECORD_EVENT(OpExecuteEvent, apply_id);
    // Begin profiling operator
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
            }
        }
    }
    for (auto* input: cmd.inputs) {
        auto input_id = input->id;
        RECORD_EVENT(OpInputEvent, input_id);
        RECORD_EVENT(TensorUsageEvent, input_id);
        RECORD_EVENT(OpInputFinishEvent, input_id);
    }
    // Fused by command buffer. @see: CommandBuffer::fuse_del
    // Now if dest is inplacable, its refcnt would be decreased to 1 and owned by tensor_inputs after Del.
    // Note that for exprs like 'y = x op x', inplace is unsupported yet, but Del would also be fused.
    for (auto* del : cmd.dels) {
        // refcnt --, owners: [tensor_inputs]
        // if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor
        uint64_t del_id = del->id;
        RECORD_EVENT(OpDelEvent, del_id);
        free(del);
        RECORD_EVENT(OpDelFinishEvent, del_id);
    }
    // Before wait
    // TODO: split operator wait and execute so that OpWait could be correctly recorded.
    // Before execute
    for (auto&& [device, kernel_id]: kernels) {
        RECORD_EVENT(KernelExecuteEvent, apply_id, kernel_id, Timer::record_event(device));
    }
    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
        auto_evict();
    }
    // Apply op
    // Here std::move is REQUIRED for removing duplicated references.
    auto tensor_outputs = OpDef::apply_on_physical_tensor(
        *cmd.op, std::move(tensor_inputs));
    // After execute
    for (auto&& [device, kernel_id]: kernels) {
        RECORD_EVENT(KernelExecuteFinishEvent, apply_id, kernel_id, Timer::record_event(device));
    }
    // End profiling operator
    mgb_assert(tensor_outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < tensor_outputs.size(); ++i) {
        auto output = cmd.outputs[i];
        if (output == nullptr) {
            RECORD_EVENT(OpOutputEvent, 0);
            RECORD_EVENT(OpOutputFinishEvent, 0);
        } else if (output->ptr != nullptr) {
            RECORD_EVENT(OpOutputEvent, output->id);
            RECORD_EVENT(OpOutputFinishEvent, output->id);
        } else {
            RECORD_EVENT(OpOutputEvent, output->id);
            produce_tensor(output, tensor_outputs[i]);
            RECORD_EVENT(OpOutputFinishEvent, output->id);
            sample_on_device(output->desc.comp_node, false);
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
        for (auto i : tensor_outputs) {
            estimate_compute_time += i->blob()->size();
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
        m_dtr.unpin(cmd.inputs);
    }
    RECORD_EVENT(OpExecuteFinishEvent, apply_id);
    // End profiling operator
}

void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
    auto& state = get_worker_state();
    do_apply_op(ApplyOp{path->id, path->op, path->inputs, path->outputs, {}});
    for (size_t i = 0;i < path->outputs.size();i ++) {
        auto&& o = path->outputs[i];
        if (o) {
            o->recompute_times ++;
            if (!o->ptr) {
                if (state.options.enable_dtr_auto_drop) {
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
    }
}

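// DTR eviction loop: while the used memory on the DTR comp node exceeds
// dtr_eviction_threshold, pick the candidate with the lowest heuristic score
// (see DynamicSublinear::find_best_tensor) and drop it, warning once if
// nothing can be evicted.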
void ChannelImpl::auto_evict() {
    auto& state = get_worker_state();
    if (!m_dtr.comp_node.valid()) {
        return;
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
    while (current_memory > state.options.dtr_eviction_threshold) {
        RECORD_EVENT(AutoEvictEvent);
        sample_on_device(m_dtr.comp_node, false);
        auto best = m_dtr.find_best_tensor();
        if (!best) {
            if (!m_dtr.warn_printed) {
                m_dtr.warn_printed = true;
                mgb_log_warn("No tensors on %s can be evicted automatically "
                             "when memory usage is %.0lfMB. Maybe memory "
                             "budget is too small.",
                              m_dtr.comp_node.to_string().c_str(),
                              current_memory / 1024.0 / 1024.0);
            }
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
        }
        sample_on_device(m_dtr.comp_node, false);
        RECORD_EVENT(AutoEvictFinishEvent);
    }
}

void ChannelImpl::detach_users(TensorInfo* dest) {
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user: users) {
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output: outputs) {
            // When a `ComputePath` is detached from its input,
            // there is no need to preserve it, so we detach all
            // outputs of this path to decrease its `ref_cnt` to zero.
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
            for (auto* input: inputs) {
                input->ref_cnt --;
            }
        }
        // now user is dead
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}

bool ChannelImpl::check_available() {
    return !m_closed;
}

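// Blocks the channel thread until `info` satisfies `prop`: for HostValue a
// GetValue command is enqueued once and we wait until the worker has fetched
// the value; for the other props we wait until info->ptr exists. Exceptions
// raised on the worker are re-thrown here via check_worker_exc_unsafe.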
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    m_buffer.flush();
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool require_host = prop == TensorProp::HostValue;
    bool value_fetching = false;
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        if (require_host) {
            if (info->ptr && info->ptr->value_fetched()) {
                return true;
            }
            if (!value_fetching) {
                m_buffer.enqueue(GetValue{info});
                value_fetching = true;
            }
            return false;
        } else {
            return static_cast<bool>(info->ptr);
        }
    });
    RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr);
    if (m_waitee != nullptr) {
        mgb_assert(m_waitee == info, "waitee mismatch");
        m_waitee = nullptr;
    }
    return info->ptr;
}

void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        m_waitee = nullptr;
        RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
    }
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
    for (auto* handle: m_valid_handle) {
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
    //TODO: valid_tensors.insert({info, info->status});
    }
    return valid_tensors;
}

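// Worker thread entry: dispatches one buffered command to the matching
// handler below. When catch_worker_execption is enabled, any exception is
// stored in m_worker_exc, the affected outputs are marked invalid, and a
// blocked waiter is woken so the channel thread can re-throw the error.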
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    // TODO: remove std::visit to support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, Put>) {
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Put);
                auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
                produce_tensor(cmd.dest, std::move(value));
                RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Put);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, ApplyOp>) {
                do_apply_op(cmd);
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
                    if (state.options.enable_dtr_auto_drop) {
                        output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                    }
                }
                if (state.options.enable_drop && state.options.record_computing_path) {
                    auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                        auto& input = std::get<0>(tuple2);
                        auto& output = std::get<1>(tuple2);
                        if (!input->ptr || !output->ptr) {
                            return false;
                        }
                        return input->ptr->blob()->storage() == output->ptr->blob()->storage();
                    };
                    // FIXME: do not use opname as identifier
                    auto get_name = [](const OpDef& opdef) {
                        if (auto attr = opdef.try_cast_final<OprAttr>()) {
                            return attr->type.c_str();
                        }
                        return opdef.dyn_typeinfo()->name;
                    };

                    auto is_cross_cn = [comp_node=m_dtr.comp_node](TensorInfo* info){
                        return info->desc.comp_node != comp_node;
                    };

                    bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                    bool inplace = any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                    if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                        TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs);
                        size_t detach_cnt = 0;
                        for (auto output : cmd.outputs) {
                            if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
                                output->detach_producer();
                                detach_cnt ++;
                            }
                        }
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= detach_cnt;
                        }
                    }
                }
            } else if constexpr (std::is_same_v<T, Del>) {
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Del);
                CompNode device = cmd.dest->desc.comp_node;
                uint64_t tensor_id = cmd.dest->id;
                free(cmd.dest);
                RECORD_EVENT(TensorCommandFinishEvent, tensor_id, TensorCommandFinishEvent::Del);
                sample_on_device(device, false);
            } else if constexpr (std::is_same_v<T, GetValue>) {
                imperative_log_profile_begin("GetValue");
                if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                    regenerate(cmd.dest);
                }
                mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
                cmd.dest->ptr->fetch_value();
                MGB_LOCK_GUARD(m_mutex);
                notify_tensor_unsafe(cmd.dest);
                imperative_log_profile_end("GetValue");
            } else if constexpr (std::is_same_v<T, SwapIn>) {
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapIn);
                produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
                RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapIn);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, SwapOut>) {
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapOut);
                cmd.dest->h_value = cmd.dest->ptr->get_value();
                if (cmd.dest->evict_type == EvictType::NONE) {
                    cmd.dest->evict_type = EvictType::SWAP;
                    cmd.dest->status = TensorInfo::Swapped;
                    release_tensor(cmd.dest);
                }
                RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapOut);
                sample_on_device(cmd.dest->desc.comp_node, false);
            } else if constexpr (std::is_same_v<T, Drop>) {
                RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Drop);
                do_drop(cmd.dest, true);
                RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Drop);
            } else if constexpr (std::is_same_v<T, SetOption>) {
                options.set_option(cmd.key, cmd.value);
            } else if constexpr (std::is_same_v<T, StartProfile>) {
                RECORD_EVENT(StartProfileEvent);
                CompNode::sync_all();
                for (auto* info: cmd.capture_tensors) {
                    RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                    if (info->status == TensorInfo::Produced) {
                        // TODO: handle swap/drop
                        RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                    }
                }
                CompNode::foreach([&](CompNode device){
                    if (Profiler::get_option("sample_rate", 0)) {
                        sample_on_device(device, true);
                    }
                });
                RECORD_EVENT(StartProfileFinishEvent);
            } else if constexpr (std::is_same_v<T, StopProfile>) {
                RECORD_EVENT(StopProfileEvent);
                for (auto* info: cmd.escape_tensors) {
                    bool has_value = info->status == TensorInfo::Produced;
                    if (has_value) {
                        RECORD_EVENT(TensorReleaseEvent, info->id);
                    }
                    RECORD_EVENT(TensorEraseEvent, info->id);
                }
                CompNode::foreach([&](CompNode device){
                    if (Profiler::get_option("sample_rate", 0)) {
                        sample_on_device(device, true);
                    }
                });
                RECORD_EVENT(StopProfileFinishEvent);
            } else if constexpr (std::is_same_v<T, PushScope>) {
                RECORD_EVENT(ScopeEvent, cmd.scope_name);
            } else if constexpr (std::is_same_v<T, PopScope>) {
                RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
            } else {
                static_assert(!std::is_same_v<T, T>);
            }
    };
    std::visit([&](const auto& cmd){
        using T = std::decay_t<decltype(cmd)>;
        if (!options.catch_worker_execption) {
            cmd_visitor(cmd);
            return;
        }
        try {
            cmd_visitor(cmd);
        } catch (...) {
            MGB_LOCK_GUARD(m_mutex);
            if constexpr (std::is_same_v<T, ApplyOp>) {
                for (auto oup : cmd.outputs) {
                    oup->invalid = true;
                }
            } else if constexpr (std::is_same_v<T, Put>) {
                cmd.dest->invalid = true;
            }
            m_worker_exc = std::current_exception();
            RECORD_EVENT(WorkerExceptionEvent);
            if (m_waitee) {
                notify_tensor_unsafe(m_waitee);
            }
        }
    }, icmd.second);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
        std::rethrow_exception(exc);
    }
}

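// The command buffer delays handing commands to the worker so that a Del can
// be fused into the pending ApplyOp that consumes the tensor (see fuse_del)
// and so that at most options.buffer_length commands stay buffered; ApplyOps
// with cross-channel side effects (remote send/recv, collectives, callbacks)
// and GetValue force a full flush.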
void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
        return;
    }
    // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
    m_commands.push_back(std::move(cmd));
    auto flush_pos = flush_pos_for(m_commands.back());
    flush(flush_pos);
}

void ChannelImpl::CommandBuffer::flush() {
    flush(m_commands.end());
}

void ChannelImpl::CommandBuffer::flush(Handle pos) {
    auto& state = m_owner->get_channel_state();
    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
        if (Profiler::is_profiling()) {
            mgb_log_debug("%s Flushed", to_string(*iter).c_str());
        }
        m_owner->m_worker.add_task(IdentifiedCommand{Profiler::next_id(), std::move(*iter)});
    }
    m_commands.erase(m_commands.begin(), pos);
}

auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
    auto& state = m_owner->get_channel_state();
    return std::visit([this, &state](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, ApplyOp>) {
            auto* op_type = cmd.op->dyn_typeinfo();
            if (op_type == RemoteRecv::typeinfo() ||
                op_type == RemoteSend::typeinfo() ||
                op_type == CollectiveComm::typeinfo() ||
                op_type == opr::InputCallback::typeinfo() ||
                op_type == opr::OutputCallback::typeinfo()) {
                return m_commands.end();
            }
        } else if constexpr (std::is_same_v<T, GetValue>) {
            return m_commands.end();
        }
        size_t buffer_length = state.options.buffer_length;
        if (m_commands.size() > buffer_length) {
            return m_commands.begin() + (m_commands.size() - buffer_length);
        }
        return m_commands.begin();
    }, cmd);
}

/**
 * 1. Find ApplyOp(dest) in buffered commands
 * 2. Check if there are other usages between ApplyOp and Del, return false if not
 * 3. Fuse Del into ApplyOp, return true
 */
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
    auto* dest = cmd.dest;
    // TODO: eliminate Puts
    auto begin = m_commands.begin(), end = m_commands.end();
    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
        if (auto* apply = std::get_if<ApplyOp>(&cmd)) {
            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
        }
        return false;
    });
    if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
        return false;
    }
    // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
    std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
    return true;
}

auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
        -> Handle {
    auto found = range[1];
    for (auto iter = range[0]; iter != range[1]; ++iter) {
        std::visit([&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
                               dest) > 0) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (cmd.dest == dest) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, SwapIn> ||
                    std::is_same_v<T, SwapOut> ||
                    std::is_same_v<T, Drop>) {
                //TODO: ignore swap-like commands, just remove them from buffer
                if (cmd.dest == dest) {
                    found = iter;
                }
            }
        }, *iter);
    };
    return found;
}

auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
        -> Handle {
    return std::find_if(range[0], range[1], [dest](auto& cmd) {
        return std::visit([dest](const auto& cmd){
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
            } else if constexpr (std::is_same_v<T, Put>) {
                return cmd.dest == dest;
            }
            return false;
        }, cmd);
    });
}

void ChannelImpl::start_profile() {
    mgb_assert(check_available(), "Channel already closed");
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
        m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
    }
}

void ChannelImpl::stop_profile() {
    mgb_assert(check_available(), "Channel already closed");
    m_buffer.flush();
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
        m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
    }
}

void ChannelImpl::push_scope(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.push(name);
    RECORD_EVENT(ScopeEvent, name);
    m_buffer.enqueue(PushScope{name});
}

void ChannelImpl::pop_scope(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.pop(name);
    RECORD_EVENT(ScopeFinishEvent, name);
    m_buffer.enqueue(PopScope{name});
}

void ChannelImpl::assert_in_channel() {
    mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
}

void ChannelImpl::assert_in_worker() {
    mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
}

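// Records the device memory status (total/free bytes) for the profiler;
// unless forced, sampling is throttled by the "sample_rate" profiler option.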
void ChannelImpl::sample_on_device(CompNode device, bool force) {
    if (!force) {
        thread_local int last_sample_id = 0;
        int sample_rate = Profiler::is_profiling() ? Profiler::get_option("sample_rate", 0) : 0;
        if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
            return;
        }
    }
    RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}

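// DynamicSublinear holds the DTR bookkeeping: dropped tensors that are
// neighbours in the compute graph are merged into a disjoint-set (DSU) whose
// root accumulates their total recompute time, estimate_neighbor_cost sums
// those costs around a candidate, and find_best_tensor picks the candidate
// with the lowest eval_func score, roughly trading recompute cost against
// freed memory and recency of use.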
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
    }
}

void ChannelImpl::DynamicSublinear::unpin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->unpin();
    }
}

void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    dsu_fa->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    double cost = 0;
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    return cost;
}

TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() {
    double min_msps = -1;
    TensorInfo* best = nullptr;
    for (auto i : candidates) {
        if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
            double neighbor_cost = estimate_neighbor_cost(i);
            size_t begin_ptr = reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(begin_ptr, begin_ptr + i->ptr->blob()->size());
            double free_mem = side_info.first + side_info.second;
            double msps = i->eval_func(neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
    }
    return best;
}

void ChannelImpl::DynamicSublinear::merge(std::shared_ptr<DsuNode> &x, std::shared_ptr<DsuNode> &y) {
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(std::shared_ptr<DsuNode>& x) {
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    candidates.insert(ptr);
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    candidates.erase(ptr);
}

void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}