interpreter_impl.cpp 56.8 KB
Newer Older
M
Megvii Engine Team 已提交
1
/**
2
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
M
Megvii Engine Team 已提交
3 4
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
M
Megvii Engine Team 已提交
6 7 8 9 10 11
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12
#include "./interpreter_impl.h"
13

14 15
#include "range/v3/all.hpp"

16
#include "megbrain/common.h"
17 18
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
19 20
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
21
#include "megbrain/imperative/ops/utility.h"
22 23
#include "megbrain/imperative/utils/to_string.h"

24
#include "../blob_manager_impl.h"
25 26 27
#include "../event_pool.h"
#include "../op_trait.h"

28 29 30 31 32
using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

33
namespace {
M
Megvii Engine Team 已提交
34 35 36 37 38 39 40 41
auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
    SmallVector<uint64_t> tid;
    for (auto* ptinfo : tinfo) {
        tid.push_back(ptinfo->id);
    }
    return tid;
};
}  // namespace
42

43
namespace mgb {
// Bring profiler event types (used by the MGB_RECORD_EVENT calls throughout
// this file) into the mgb namespace.
using namespace profiler;
}

47 48 49 50 51
#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif

namespace mgb {

/**
 * USAGE
 *
 *   header:
 *     namespace mgb { void imperative_log_profile(const char* message); }
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
// Open a custom profiler interval labelled `message`.
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    MGB_RECORD_EVENT(CustomEvent, std::string{message});
}

// Close the custom profiler interval opened with the same `message`.
SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
}

// Record a zero-length marker: begin immediately followed by end.
SYMBOL_EXPORT
void imperative_log_profile(const char* message) {
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

// Device-aware variant: also records a device timing event on `device`
// so the interval can be correlated with the device timeline.
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message, const char* device) {
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
}

// Device-aware end marker; records the device event before closing the
// interval (mirror order of the begin variant).
SYMBOL_EXPORT
void imperative_log_profile_end(const char* message, const char* device) {
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
}

}  // namespace mgb
98

99 100 101 102 103 104 105 106 107 108 109 110 111 112
// Thread id of the worker thread (set in on_async_queue_worker_thread_start).
std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

// Accessor for channel-thread-local state; asserts the caller is actually
// the channel (user) thread. Use this instead of touching m_channel_state
// directly — the name is poisoned by a macro further down.
ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

// Accessor for worker-thread-local state; asserts the caller is the worker.
ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

113 114 115 116 117 118 119 120 121 122
void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
    OpDef::set_allocator([&](CompNode device, size_t size) {
        auto blob = Blob::make(device, size);
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
    });
}

123
// Do not use m_xxx_state directly below this point: these macros poison the
// member names so that every access must go through get_channel_state() /
// get_worker_state(), which assert that the correct thread is calling.
#define m_channel_state
#define m_worker_state

127 128 129 130 131 132 133 134 135
// Create a fresh channel; ownership is transferred to the caller.
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    auto channel = std::make_unique<ChannelImpl>();
    return channel;
}

// Process-wide interpreter singleton (Meyers singleton, thread-safe init).
Interpreter& Interpreter::inst() {
    static InterpreterImpl instance;
    return instance;
}

136
// Register a host tensor with the channel and return an opaque handle.
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    return reinterpret_cast<Handle>(put_impl(value, no_cache));
}

// Core of put(HostTensorND): allocates a TensorInfo, records the host value
// and enqueues an async Put command for the worker thread.
TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    if (value.empty()) {
        // NOTE: const_cast — an empty tensor gets its layout normalized to
        // contiguous strides in place, mutating the caller's tensor object.
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto info = alloc();
    init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
    info->h_value = value;
    m_buffer.enqueue(Put{info, value, no_cache});
    if (m_async_level == 0) {
        // Fully synchronous mode: wait for the worker, then surface any
        // pending async device error immediately.
        sync_impl();
        info->desc.comp_node.sync();
        auto err = info->desc.comp_node.check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info;
}

165
// Register an already-materialized device tensor (with its host mirror)
// and return an opaque handle.
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto* info = put_impl(data, hvalue);
    return reinterpret_cast<Handle>(info);
}
M
Megvii Engine Team 已提交
170 171
// Core of put(DeviceTensorND): unlike the host path there is nothing to
// enqueue — the tensor already exists on device, so the info is produced
// synchronously on the calling thread.
TensorInfo* ChannelImpl::put_impl(
        const DeviceTensorND& data, const HostTensorND& hvalue) {
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    auto info = alloc();
    MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
    init(info, {data.layout(), data.comp_node()});
    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
    info->ptr = Tensor::make(data, hvalue);
    MGB_RECORD_EVENT(
            TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
            data.raw_ptr());
    info->status = TensorInfo::Produced;
    MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
    return info;
}

187
// Release a handle. Silently a no-op after the channel has been closed
// (close() already deleted every remaining handle).
void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    if (check_available()) {
        del_impl(handle);
    }
}

// Unregister the handle and queue the actual deletion on the worker.
void ChannelImpl::del_impl(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    m_valid_handle.erase(handle);
    m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
}

202
// Request eviction ("drop") of a tensor's device value; only honoured when
// the enable_drop option is set, otherwise silently ignored.
void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (!state.options.enable_drop) {
        return;
    }
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    m_buffer.enqueue(Drop{reinterpret_cast<TensorInfo*>(handle)});
}

215
// Execute an op eagerly on the (default) CPU using host values, bypassing
// the worker queue entirely. Inputs must either hold a fetchable device
// value or a cached host value; outputs are re-registered via put_impl so
// they behave like any other tensor.
void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();

    auto name = op->trait()->make_name(*op);
    auto _ = StackManager::Guard(name, &state.stack_manager);

    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_RECORD_EVENT(ShapeInferEvent, validated);

    // Gather input values as CPU-proxied DeviceTensorNDs; all inputs must
    // agree on a single comp node, which becomes the output comp node.
    SmallVector<DeviceTensorND> input_tensornds;
    input_tensornds.reserve(input_descs.size());
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(
                        info->ptr->get_value().proxy_to_default_cpu());
            } else {
                // We assign h_value before drop ptr
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    // Pre-allocate output storage with the inferred layouts.
    outputs->reserve(output_descs.size());
    SmallVector<DeviceTensorND> output_tensornds;
    output_tensornds.reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which need alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(
                HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    uint64_t op_id = Profiler::next_id();

    // Run the op synchronously on the host tensors.
    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);

    // Wrap each result back into a channel-managed tensor.
    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd =
                HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(reinterpret_cast<Handle>(info));
    }
    // Lazily-evaluated op metadata for the profiler.
    auto op_info_getter = [op] {
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value] : props) {
            op_info[key] = value;
        }
        return op_info;
    };
    MGB_RECORD_EVENT(
            OpDispatchEvent, op_id, name, op_info_getter, tinfo_to_tid(input_infos),
            tinfo_to_tid(output_infos), state.stack_manager.dump());
}
290

291
void ChannelImpl::dispatch_kernel(
M
Megvii Engine Team 已提交
292
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
293 294
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
295
    auto& state = get_channel_state();
296 297 298
    auto& options = state.options;

    auto name = op->trait()->make_name(*op);
M
Megvii Engine Team 已提交
299
    auto _ = StackManager::Guard{name, &state.stack_manager};
300

M
Megvii Engine Team 已提交
301 302
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
303
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
304

305
    ApplyOp cmd{Profiler::next_id(), std::move(op)};
306
    cmd.inputs = std::move(input_infos);
307
    cmd.outputs.reserve(output_descs.size());
308
    outputs->reserve(output_descs.size());
309 310
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
311
        auto info = alloc();
312
        init(info, desc);
313 314 315
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
M
Megvii Engine Team 已提交
316
                                    .proxy_to_comp_node(desc.comp_node);
317
        }
318
        cmd.outputs.push_back(info);
M
Megvii Engine Team 已提交
319
        outputs->push_back(reinterpret_cast<Handle>(info));
320
    }
M
Megvii Engine Team 已提交
321
    auto op_info_getter = [op = cmd.op] {
322 323
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
M
Megvii Engine Team 已提交
324
        for (auto&& [key, value] : props) {
325 326 327 328
            op_info[key] = value;
        }
        return op_info;
    };
M
Megvii Engine Team 已提交
329 330 331
    MGB_RECORD_EVENT(
            OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
            tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
332
    m_buffer.enqueue(std::move(cmd));
333
    if (!validated && options.async_level == 1) {
334
        sync_impl();
335
    } else if (options.async_level == 0) {
336
        sync_impl();
337
        // check device error
338
        for (auto&& oup : *outputs) {
339 340
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
341 342
            auto err = info->ptr->comp_node().check_async_error();
            mgb_assert(!err, "%s", err->what());
343
        }
344
    }
345 346 347
}

// Public entry point for executing an op: takes the spin lock, checks the
// channel is open, then delegates to apply_op_impl.
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return apply_op_impl(std::move(op), inputs);
}

// Validate inputs, snapshot their descriptors under the mutex, then route
// the op either to the eager CPU path or the async kernel path.
SmallVector<Handle> ChannelImpl::apply_op_impl(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(
                m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
                i);
    }
    SmallVector<TensorInfo*> input_infos;
    input_infos.reserve(inputs.size());
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.reserve(inputs.size());
    {
        // Descs are copied under the lock so the worker cannot mutate them
        // mid-read.
        MGB_LOCK_GUARD(m_mutex);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(
                    !info->invalid,
                    "an input tensor is unusable due to previous error");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    // Host compute may only be chosen when the option allows it; otherwise
    // always go through the kernel path.
    DispatchMode dispatch_mode = state.options.enable_host_compute
                                       ? OpDef::decide_dispatch_mode(*op, input_descs)
                                       : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

395
// Block until the tensor's host value is available and return it.
HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    // do not rely on info->value_fetched here — it is unsafe without a lock
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    auto tensor = wait_tensor(info, TensorProp::HostValue);
    return tensor->get_value();
}

407
// Return the tensor's shape, waiting on the worker only when shape
// inference did not already determine it.
TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    // Fast path: shape already known from dispatch-time inference.
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape shape = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(shape.ndim != 0);
    return shape;
}

422
// Return the tensor's dtype; never blocks, the dtype is always known
// from the descriptor.
DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto dtype = info->desc.layout.dtype;
    mgb_assert(dtype.valid());
    return dtype;
}

435
// Return the tensor's comp node; never blocks, the device is always known
// from the descriptor.
CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto device = info->desc.comp_node;
    mgb_assert(device.valid());
    return device;
}

448
// Block until the tensor's device value exists and return it.
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    auto tensor = wait_tensor(info, TensorProp::DevValue);
    return tensor->dev_tensor();
}

// Public sync: barrier against the worker thread.
void ChannelImpl::sync() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    sync_impl();
}

// Flush pending commands, wait for the worker to drain, then re-raise any
// exception the worker recorded.
void ChannelImpl::sync_impl() {
    m_buffer.flush();
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

// Shut the channel down: delete every live handle, drain the worker, and
// mark the channel closed. Idempotent — a second call is a no-op.
void ChannelImpl::close() {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    // Snapshot the handles first: del_impl mutates m_valid_handle, so we
    // must not iterate the set directly.
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
    for (auto* handle : valid_handles) {
        del_impl(handle);
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
    sync_impl();
    m_closed = true;
}

486
size_t ChannelImpl::get_option(std::string name) {
487
    MGB_LOCK_GUARD(m_spin);
488
    mgb_assert(check_available(), "Channel already closed");
489 490
    auto& state = get_channel_state();
    return state.options.get_option(name);
491 492
}

493
void ChannelImpl::set_option(std::string name, size_t value) {
494
    MGB_LOCK_GUARD(m_spin);
495
    mgb_assert(check_available(), "Channel already closed");
496 497
    auto& state = get_channel_state();
    state.options.set_option(name, value);
498
    m_buffer.enqueue(SetOption{name, value});
499 500
}

501 502 503 504 505 506
// Clear the DTR eviction-candidate set.
void ChannelImpl::clear_candidates() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_dtr.candidates.clear();
}

507
// Allocate a TensorInfo from the pool (pool access is guarded by m_mutex)
// and give it a fresh profiler id; when profiling, also a human-readable
// name derived from the current dispatch stack.
TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this] {
        MGB_LOCK_GUARD(m_mutex);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        size_t tensor_id = state.stack_manager.current()->next_id("tensor");
        info->name =
                state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
    }
    return info;
}

522
// Register a freshly-allocated TensorInfo as a valid handle and seed its
// logical and memory descriptors from `desc`.
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
    m_valid_handle.insert(reinterpret_cast<Handle>(info));
    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    info->desc = std::move(desc);
    info->mem_desc.layout = info->desc.layout;
    info->mem_desc.cn = info->desc.comp_node;
    info->mem_desc.offset = 0;
}

M
Megvii Engine Team 已提交
532
// Evict a tensor's device value (DTR "drop"): only possible when the tensor
// has a recorded producer so it can be recomputed later. `user` marks drops
// explicitly requested by the user, which warrant a warning when ignored.
// NOTE(review): default argument in an out-of-class definition — assumes
// the in-class declaration does not also provide one; confirm in header.
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn(
                    "the input that produced tensor %p has been deleted, this drop "
                    "operation will be ignored",
                    ptr);
        }
        return;
    }
    // Already evicted (drop or swap) — nothing to do.
    if (ptr->evict_type != EvictType::NONE) {
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

550
// Free a tensor. Under DTR auto-drop the tensor may still be referenced as
// a recompute dependency (ref_cnt > 0), in which case it is only dropped;
// otherwise it (and any now-unreferenced ancestors) are freed for real.
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

// Free `ptr` and cascade: any producer input whose ref count drops to zero
// and is already marked deletable is freed as well.
void ChannelImpl::recursive_free(TensorInfo* ptr) {
    MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
    // Inputs that become unreferenced once this tensor's producer edge goes.
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    // Free ptr first, then recurse into the released inputs.
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
}

// Actually destroy a TensorInfo: remove it from DTR bookkeeping, detach it
// from the producer/user graph, record the release, and return it to the
// pool (pool access guarded by m_mutex).
void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    MGB_LOCK_GUARD(m_mutex);
    m_pool.free(ptr);
}

M
Megvii Engine Team 已提交
603
// Worker queue and command buffer both hold a back-pointer to the channel.
ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this) {}

// Destructor delegates to close(), which is safe to call repeatedly.
ChannelImpl::~ChannelImpl() {
    close();
}
608

609
// Attach a computed value to `dest` (worker thread), update descriptors for
// static inference, register it as a DTR eviction candidate when large
// enough, and wake any thread blocked in wait_tensor.
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
            ptr->dev_tensor().raw_ptr());
    // update tensor desc for static infer
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    // Read blob size before ptr is moved away below.
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    if (dest->pinned == 0 &&
        dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

630
// Drop the device value of `dest` (keeps the TensorInfo itself alive) and
// remove it from the DTR candidate set if it was tracked there.
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
}

640
// Recompute a dropped tensor by replaying its producer op. The op is pushed
// onto the apply stack; if we are not already inside flush_apply_stack
// (i.e. not a nested regeneration), drain the stack now.
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        auto&& path = dest->producer;
        m_apply_stack.push(
                {ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest,
                 "dtr"});
        if (!m_applying)
            flush_apply_stack();
    }
}

651
// Worker-side execution of one ApplyOp command: gathers input tensors,
// optionally auto-evicts under DTR pressure, runs the op (recursing through
// subgraphs when the op provides a forward graph), publishes the outputs,
// and emits all profiling events. `reason` tags the execution origin
// (e.g. "dtr" for recomputation).
void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    bool profiling_device =
            Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
    uint64_t apply_id = cmd.id;
    // A tensor value paired with its memory descriptor; used throughout the
    // recursive application below.
    struct TensorWithDesc {
        TensorPtr tensor;
        MemoryDesc desc;
    };
    SmallVector<TensorWithDesc> inputs;
    inputs.reserve(cmd.inputs.size());
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
        // refcnt ++, owners: [i->ptr, tensor_inputs]
        // tensor_inputs.push_back(i->ptr);
        inputs.push_back({i->ptr, i->mem_desc});
    }
    if (state.options.enable_dtr_auto_drop &&
        state.options.dtr_eviction_threshold > 0) {
        auto_evict(0);
    }
    // Recursive applier (passed to itself as `self` to allow recursion from
    // a lambda): ops with a forward graph are expanded and applied node by
    // node; leaf ops are executed directly on physical tensors.
    auto apply_on_physical_tensor =
            [&](auto&& self, const OpDef& def,
                SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
        auto apply_functor = [&](std::shared_ptr<OpDef> op,
                                 SmallVector<TensorWithDesc> inputs,
                                 size_t nr_outputs) -> SmallVector<TensorWithDesc> {
            auto opname = op->trait()->make_name(*op);
            imperative_log_profile_begin(opname.c_str());
            auto outputs = self(self, *op, inputs);
            imperative_log_profile_end(opname.c_str());
            return outputs;
        };
        auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
            return {value, MemoryDesc{
                                   value->layout(), 0, value->comp_node(),
                                   StorageIdentifier::make()}};
        };
        if (def.trait()->make_forward_graph) {
            // apply recursivily
            SmallVector<LogicalTensorDesc> input_descs;
            for (auto&& input : inputs) {
                input_descs.push_back(
                        {{{}, input.tensor->dtype()}, input.tensor->comp_node()});
            }
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
            auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
            return outputs;
        }
        SmallVector<TensorPtr> input_tensors;
        SmallVector<MemoryDesc> input_descs;
        for (auto&& input : inputs) {
            input_tensors.push_back(input.tensor);
            input_descs.push_back(input.desc);
        }
        auto [output_descs, output_tensors, workspaces] =
                init_output_and_workspace(def, input_tensors, input_descs);
        if (!output_descs.empty()) {
            OpDef::execute(def, input_tensors, output_tensors, workspaces);
        } else {
            // No pre-planned outputs: let the op allocate its own results,
            // then synthesize descriptors for them.
            output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors);
            for (auto&& output_tensor : output_tensors) {
                output_descs.push_back(MemoryDesc{
                        output_tensor->layout(), 0, output_tensor->comp_node(),
                        StorageIdentifier::make()});
            }
        }
        SmallVector<TensorWithDesc> outputs;
        for (auto&& [output_tensor, output_desc] :
             ranges::zip_view(output_tensors, output_descs)) {
            outputs.push_back({output_tensor, output_desc});
        }
        return outputs;
    };
    MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
    // Begin profiling operator
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
            }
        }
    }
    for (auto* input : cmd.inputs) {
        auto input_id = input->id;
        MGB_RECORD_EVENT(OpInputEvent, input_id);
        MGB_RECORD_EVENT(TensorUsageEvent, input_id);
        MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
    }
    // Fused by command buffer. @see: CommandBuffer::fuse_del
    // Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by
    // tensor_inputs after Del. Note for exprs like 'y = x op x', inplace is unsupported
    // yet but Del would be also fused.
    for (auto* del : cmd.dels) {
        // refcnt --, owners: [tensor_inputs]
        // if it's decreased to 1, would be detected at @see:
        // proxy_graph_detail::apply_on_physical_tensor
        uint64_t del_id = del->id;
        MGB_RECORD_EVENT(TensorCommandEvent, del_id, TensorCommandKind::Del);
        free(del);
        MGB_RECORD_EVENT(TensorCommandFinishEvent, del_id, TensorCommandKind::Del);
    }
    // Before wait
    // TODO: split operator wait and execute so that OpWait could be corrected recorded.
    // Before execute
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
        MGB_RECORD_EVENT_IF(
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
    }
    // Apply op
    // Here std::move is REQUIRED for removing duplicated references.
    auto outputs = apply_on_physical_tensor(apply_on_physical_tensor, *cmd.op, inputs);
    // After execute
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT_IF(
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
        MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
    }
    // End profiling operator
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto output = cmd.outputs[i];
        if (output == nullptr) {
            // Output was deleted before execution finished.
            MGB_RECORD_EVENT(OpOutputEvent, 0);
            MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
        } else if (output->ptr != nullptr) {
            // Output already has a value (e.g. produced concurrently).
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
        } else {
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            produce_tensor(output, outputs[i].tensor);
            output->mem_desc = outputs[i].desc;
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
            sample_on_device(output->desc.comp_node, false);
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        // DTR cost model: approximate compute time by total bytes touched.
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
        for (auto i : outputs) {
            estimate_compute_time += i.tensor->blob()->size();
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
        // Inputs were pinned for the duration of the op; release them.
        m_dtr.unpin(cmd.inputs, state);
    }
    MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
    // End profiling operator
}
817

818 819
void ChannelImpl::flush_apply_stack() {
    m_applying = true;
820
    auto& state = get_worker_state();
821
    while (!m_apply_stack.empty()) {
M
Megvii Engine Team 已提交
822 823
        auto& [cmd, idx, recomp, reason] =
                m_apply_stack.top();  // cmd.inputs[0~idx-1] is in memory
824 825 826 827 828
        if (idx == 0) {
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
M
Megvii Engine Team 已提交
829 830
                MGB_RECORD_EVENT(
                        TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
831 832 833
            }
        }
        bool regen = false;
M
Megvii Engine Team 已提交
834
        for (size_t i = idx; i < cmd.inputs.size(); i++) {
835 836 837 838 839 840
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                idx = i + 1;
M
Megvii Engine Team 已提交
841
                regenerate(p);  // add ApplyOp to the stack
842 843 844 845
                regen = true;
                break;
            }
        }
M
Megvii Engine Team 已提交
846 847
        if (regen)
            continue;
848
        // the required input tensors are already in memory
M
Megvii Engine Team 已提交
849 850
        auto [cmd_backup, recomp_backup, reason_backup] =
                std::make_tuple(cmd, recomp, reason);
851
        m_apply_stack.pop();
852
        do_apply_op(cmd_backup, reason_backup);
853
        if (recomp_backup) {
M
Megvii Engine Team 已提交
854 855 856
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, recomp_backup->id,
                    TensorCommandKind::ReGen);
857 858
            for (auto o : cmd_backup.outputs) {
                if (o) {
859 860 861 862
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
863
    }
864
    m_applying = false;
865 866
}

867
// Evict DTR candidates until used memory drops below the configured
// threshold, honoring at least `force_num` forced evictions on top.
// Returns whether any eviction actually released memory.
bool ChannelImpl::auto_evict(size_t force_num) {
    auto& state = get_worker_state();
    if (!m_dtr.comp_node.valid()) {
        return false;
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
    // fixed: was `size_t flag = false;` — a bool stored in a size_t, relying
    // on implicit conversions both at assignment and at return
    bool evicted = false;
    while ((state.options.dtr_eviction_threshold > 0 &&
            current_memory > state.options.dtr_eviction_threshold) ||
           force_num > 0) {
        MGB_RECORD_EVENT(AutoEvictEvent);
        sample_on_device(m_dtr.comp_node, false);
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
        if (!best) {
            // no evictable candidate left; give up
            MGB_RECORD_EVENT(AutoEvictFinishEvent);
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            // dropping frees the blob only when we hold the last reference
            current_memory -= best->memory;
            if (force_num > 0) {
                force_num--;
            }
            evicted = true;
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
        }
        sample_on_device(m_dtr.comp_node, false);
        MGB_RECORD_EVENT(AutoEvictFinishEvent);
    }
    return evicted;
}

901 902
void ChannelImpl::detach_users(TensorInfo* dest) {
    // Iterate over a snapshot: detach_producer() below mutates dest->users.
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* path : users) {
        SmallVector<TensorInfo*> outputs = path->outputs;
        SmallVector<TensorInfo*> inputs = path->inputs;
        // Once a ComputePath loses one of its inputs it can never be
        // replayed, so detach every output of the path as well, driving its
        // ref_cnt down to zero.
        for (auto* out : outputs) {
            if (out == nullptr) {
                continue;
            }
            regenerate(out);
            out->detach_producer();
            for (auto* in : inputs) {
                in->ref_cnt--;
            }
        }
        // the path object is dead at this point
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}

925 926 927 928
// Whether the channel still accepts commands (false once closed).
bool ChannelImpl::check_available() {
    return !m_closed;
}

929 930 931 932 933 934
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    // Block the calling thread until `info` carries the data that `prop`
    // requires (device value, or fetched host value).
    m_buffer.flush();
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool need_host = prop == TensorProp::HostValue;
    auto value_on_host = [&] { return info->ptr && info->ptr->value_fetched(); };
    bool need_fetch = !value_on_host();
    if (need_host && need_fetch) {
        // enqueue without holding the lock to avoid deadlock with the worker
        lock.unlock();
        m_buffer.enqueue(GetValue{info});
        m_buffer.flush();
        lock.lock();
    }
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return need_host ? value_on_host() : static_cast<bool>(info->ptr);
    });
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
    m_waitee = nullptr;
    if (need_host && need_fetch) {
        // surface any asynchronous device error raised while fetching
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info->ptr;
}

void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    // "unsafe": caller must already hold m_mutex.
    if (info != m_waitee) {
        return;
    }
    MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
    m_cv.notify_all();
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    // Gather the TensorInfo behind every live external handle.
    std::unordered_set<TensorInfo*> result;
    for (auto* handle : m_valid_handle) {
        result.insert(reinterpret_cast<TensorInfo*>(handle));
    }
    return result;
}

975
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
    // Allocate storage for `x`, falling back first to DTR eviction and
    // finally to a full-device defragmentation.
    auto try_reserve = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            if (!auto_evict(1)) {
                return false;
            }
        }
        return true;
    };
    // silence the expected allocation-failure logs while retrying
    auto pre_level = set_log_level(LogLevel::NO_LOG);
    try_reserve(x->size());
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
    MGB_CATCH(MemAllocError&, {
        bool success = false;
        while (!success) {
            if (!auto_evict(1)) {
                break;
            }
            MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
            MGB_CATCH(MemAllocError&, { continue; });
            success = true;
        }
        if (!success) {
            // eviction alone was not enough: defragment and retry once
            set_log_level(pre_level);
            mgb_log_warn(
                    "reallocating all cuda memory to alleviate fragmentation, the "
                    "performance may be affected");
            set_log_level(LogLevel::NO_LOG);
            imperative_log_profile_begin("defrag");
            BlobManager::inst()->defrag(x->comp_node());
            imperative_log_profile_end("defrag");
            BlobManager::inst()->alloc_direct(x, x->size());
        }
    });
    set_log_level(pre_level);
}

M
Megvii Engine Team 已提交
1015 1016 1017
std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
ChannelImpl::init_output_and_workspace(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<MemoryDesc> inputs_mem_desc) {
    // Infer memory plans for outputs and workspaces, then materialize the
    // backing storage for each plan.
    auto [outputs_desc, workspaces_desc] =
            OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
    if (!outputs_desc.size()) {
        // failed to infer memplan
        return {{}, {}, {}};
    }
    // refine storage id to make it unique
    for (auto&& desc : outputs_desc) {
        if (desc.id->is_sys_alloc()) {
            // TODO: there may be some outputs sharing the same storage id
            desc.id->id = ++m_storage_id;
        }
    }
    auto& state = get_worker_state();
    auto alloc_storage = [&](SmallVector<MemoryDesc>& descs) {
        SmallVector<TensorPtr> tensors;
        for (size_t i = 0; i < descs.size(); i++) {
            if (descs[i].id->is_sys_alloc()) {
                // system-allocated: fresh tensor, optionally via DTR-aware path
                tensors.push_back(Tensor::make(descs[i].layout, descs[i].cn));
                if (state.options.enable_dtr_auto_drop && !descs[i].layout.is_empty()) {
                    alloc_tensor_with_evict(tensors.back()->blob().get());
                }
            } else if (descs[i].id->is_from_other()) {
                // forwarded: reuse a slice of the matching input's storage
                for (size_t j = 0; j < inputs_mem_desc.size(); j++) {
                    if (inputs_mem_desc[j].id->desc == descs[i].id->desc) {
                        tensors.push_back(
                                inputs[j]->sub(descs[i].offset, descs[i].layout));
                        break;
                    }
                }
            } else if (descs[i].id->is_device_ptr()) {
                tensors.push_back(descs[i].id->ptr);
            } else {
                mgb_assert(0, "not implemented");
            }
        }
        return tensors;
    };

    return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
}

1061
void ChannelImpl::process_one_task(Command& icmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    // TODO: remove std::visit for support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, Put>) {
            // upload a host value into a device tensor
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
                                      : Tensor::make(cmd.value);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            produce_tensor(cmd.dest, std::move(value));
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, ApplyOp>) {
            // an invalid input poisons every output and skips execution
            for (auto& i : cmd.inputs) {
                if (i->invalid) {
                    MGB_LOCK_GUARD(m_mutex);
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
                    }
                    return;
                }
            }
            m_apply_stack.push({cmd, 0, nullptr, "cmd"});
            flush_apply_stack();
            for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                auto output = cmd.outputs[i];
                if (output == nullptr) {
                    continue;
                }
                if (state.options.enable_dtr_auto_drop) {
                    output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                }
            }
            if (state.options.enable_drop && state.options.record_computing_path) {
                auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                    auto& input = std::get<0>(tuple2);
                    auto& output = std::get<1>(tuple2);
                    if (!input->ptr || !output->ptr) {
                        return false;
                    }
                    return input->ptr->blob()->storage() ==
                           output->ptr->blob()->storage();
                };
                // FIXME: do not use opname as identifier
                auto get_name = [](const OpDef& opdef) {
                    if (auto attr = opdef.try_cast_final<OprAttr>()) {
                        return attr->type.c_str();
                    }
                    return opdef.dyn_typeinfo()->name;
                };

                auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
                    return info->desc.comp_node != comp_node;
                };

                bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                bool inplace =
                        any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                // only ops that stay on the DTR device, are not in-place, and
                // are not blacklisted get a recorded ComputePath
                if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                    TensorInfo::ComputePath::make(
                            cmd.id, cmd.op, cmd.inputs, cmd.outputs);
                    size_t detach_cnt = 0;
                    if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
                        cmd.outputs.size() == 5) {
                        cmd.outputs[0]->detach_producer();  // detach running_mean
                        cmd.outputs[1]->detach_producer();  // detach running_var
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= 2;
                        }
                    }
                    for (auto output : cmd.outputs) {
                        if (output->producer &&
                            !output->size_exceeds_thd(
                                    state.options.dtr_evictee_minimum_size)) {
                            output->detach_producer();
                            detach_cnt++;
                        }
                    }
                    for (auto input : cmd.inputs) {
                        input->ref_cnt -= detach_cnt;
                    }
                }
            }
        } else if constexpr (std::is_same_v<T, Del>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
            CompNode device = cmd.dest->desc.comp_node;
            uint64_t tensor_id = cmd.dest->id;
            free(cmd.dest);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
            sample_on_device(device, false);
        } else if constexpr (std::is_same_v<T, GetValue>) {
            if (cmd.dest->invalid)
                return;
            imperative_log_profile_begin("GetValue");
            if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                regenerate(cmd.dest);
            }
            cmd.dest->ptr->fetch_value();
            MGB_LOCK_GUARD(m_mutex);
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
            do_drop(cmd.dest, true);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
        } else if constexpr (std::is_same_v<T, SetOption>) {
            options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            MGB_RECORD_EVENT(StartProfileEvent);
            CompNode::sync_all();
            for (auto* info : cmd.capture_tensors) {
                MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                if (info->status == TensorInfo::Produced) {
                    // TODO: handle drop
                    MGB_RECORD_EVENT(
                            TensorProduceEvent, info->id, info->desc.layout,
                            info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                }
            }
            CompNode::foreach ([&](CompNode device) {
                sample_on_device(device, true);
                MGB_RECORD_EVENT_IF(
                        (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                        Timer::record_device(device));
            });
            MGB_RECORD_EVENT(StartProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
            MGB_RECORD_EVENT(StopProfileEvent);
            for (auto* info : cmd.escape_tensors) {
                bool has_value = info->status == TensorInfo::Produced;
                if (has_value) {
                    MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
                }
                MGB_RECORD_EVENT(TensorEraseEvent, info->id);
            }
            CompNode::foreach (
                    [&](CompNode device) { sample_on_device(device, true); });
            MGB_RECORD_EVENT(StopProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, PushScope>) {
            MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
            MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
        } else {
            static_assert(!std::is_same_v<T, T>);
        }
    };
    std::visit(
            [&](const auto& cmd) {
                using T = std::decay_t<decltype(cmd)>;
                if (!options.catch_worker_execption) {
                    cmd_visitor(cmd);
                    return;
                }
                try {
                    cmd_visitor(cmd);
                } catch (...) {
                    // record the exception so the waiting python thread can
                    // rethrow it, and mark affected tensors invalid
                    MGB_LOCK_GUARD(m_mutex);
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        for (auto oup : cmd.outputs) {
                            oup->invalid = true;
                        }
                    } else if constexpr (std::is_same_v<T, Put>) {
                        cmd.dest->invalid = true;
                    }
                    m_worker_exc = std::current_exception();
                    MGB_RECORD_EVENT(WorkerExceptionEvent);
                    if (m_waitee) {
                        notify_tensor_unsafe(m_waitee);
                    }
                }
            },
            icmd.data);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
1252 1253
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
1254 1255
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
1256 1257 1258 1259 1260
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
1261 1262
    }
}
1263

1264 1265
void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
    auto& state = m_owner->get_channel_state();
    // A Del can sometimes be folded into a buffered ApplyOp instead of queued.
    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
        return;
    }
    m_commands.push_back(
            {Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
    flush(flush_pos_for(m_commands.back()));
}

1275 1276 1277 1278
// Flush every buffered command to the worker thread.
void ChannelImpl::CommandBuffer::flush() {
    flush(m_commands.end());
}

1279 1280
void ChannelImpl::CommandBuffer::flush(Handle pos) {
    // Hand commands in [begin, pos) to the worker, then drop them here.
    for (auto it = m_commands.begin(); it != pos; ++it) {
        if (Profiler::is_profiling()) {
            mgb_log_debug("%s Flushed", to_string(*it).c_str());
        }
        m_owner->m_worker.add_task(std::move(*it));
    }
    m_commands.erase(m_commands.begin(), pos);
}

auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
    auto& state = m_owner->get_channel_state();
    return std::visit(
            [this, &state](const auto& cmd) {
                using T = std::decay_t<decltype(cmd)>;
                if constexpr (std::is_same_v<T, ApplyOp>) {
                    // communication and callback ops must not linger in the
                    // buffer: flush everything up to and including them
                    auto* op_type = cmd.op->dyn_typeinfo();
                    if (op_type == RemoteRecv::typeinfo() ||
                        op_type == RemoteSend::typeinfo() ||
                        op_type == CollectiveComm::typeinfo() ||
                        op_type == opr::InputCallback::typeinfo() ||
                        op_type == opr::OutputCallback::typeinfo()) {
                        return m_commands.end();
                    }
                } else if constexpr (std::is_same_v<T, GetValue>) {
                    // someone is waiting on the value: flush everything
                    return m_commands.end();
                }
                // otherwise only drain the overflow beyond buffer_length
                size_t limit = state.options.buffer_length;
                if (m_commands.size() > limit) {
                    return m_commands.begin() + (m_commands.size() - limit);
                }
                return m_commands.begin();
            },
            cmd.data);
}

/**
 * 1. Find ApplyOp(dest) in buffered commands
 * 2. Check if there are other usages between ApplyOp and Del, return false if not
 * 3. Fuse Del into ApplyOp, return true
 */
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
    auto* dest = cmd.dest;
    // TODO: eliminate Puts
    auto begin = m_commands.begin(), end = m_commands.end();
M
Megvii Engine Team 已提交
1324
    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd) {
1325
        if (auto* apply = std::get_if<ApplyOp>(&cmd.data)) {
1326 1327 1328 1329
            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
        }
        return false;
    });
M
Megvii Engine Team 已提交
1330
    if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) {
1331 1332
        return false;
    }
1333
    std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
1334 1335 1336 1337 1338 1339 1340
    return true;
}

auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
        -> Handle {
    // Scan the whole range, remembering the latest command that touches
    // `dest`; range[1] means "no usage found".
    auto found = range[1];
    for (auto iter = range[0]; iter != range[1]; ++iter) {
        std::visit(
                [&](const auto& cmd) {
                    using T = std::decay_t<decltype(cmd)>;
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        if (std::count(cmd.inputs.begin(), cmd.inputs.end(), dest) >
                            0) {
                            found = iter;
                        }
                    } else if constexpr (std::is_same_v<T, GetValue>) {
                        if (cmd.dest == dest) {
                            found = iter;
                        }
                    } else if constexpr (std::is_same_v<T, Drop>) {
                        // TODO: ignore swap-like commands, just remove them from buffer
                        if (cmd.dest == dest) {
                            found = iter;
                        }
                    }
                },
                iter->data);
    }
    return found;
}

M
Megvii Engine Team 已提交
1365
auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) -> Handle {
    // First command in range that produces `dest`: an ApplyOp listing it as
    // output, or a Put targeting it.
    return std::find_if(range[0], range[1], [dest](auto& cmd) {
        return std::visit(
                [dest](const auto& cmd) {
                    using T = std::decay_t<decltype(cmd)>;
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        return std::count(
                                       cmd.outputs.begin(), cmd.outputs.end(), dest) >
                               0;
                    } else if constexpr (std::is_same_v<T, Put>) {
                        return cmd.dest == dest;
                    }
                    return false;
                },
                cmd.data);
    });
}
1382

1383
void ChannelImpl::start_profile() {
1384
    MGB_LOCK_GUARD(m_spin);
1385
    mgb_assert(check_available(), "Channel already closed");
1386 1387 1388 1389
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
        m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
    }
1390 1391
}

1392
void ChannelImpl::stop_profile() {
1393
    MGB_LOCK_GUARD(m_spin);
1394
    mgb_assert(check_available(), "Channel already closed");
1395
    m_buffer.flush();
1396 1397 1398 1399
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
        m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
    }
1400 1401 1402
}

void ChannelImpl::push_scope(std::string name) {
    // Open a named profiling scope on both the channel and the worker.
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    get_channel_state().stack_manager.enter(name);
    MGB_RECORD_EVENT(ScopeEvent, name);
    m_buffer.enqueue(PushScope{name});
}

void ChannelImpl::pop_scope(std::string name) {
    // Close the matching named profiling scope.
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    get_channel_state().stack_manager.exit(name);
    MGB_RECORD_EVENT(ScopeFinishEvent, name);
    m_buffer.enqueue(PopScope{name});
}

1420
void ChannelImpl::assert_in_channel() {
M
Megvii Engine Team 已提交
1421 1422 1423
    mgb_assert(
            get_worker_tid() != std::this_thread::get_id(),
            "this method cannot be called in worker thread");
1424 1425 1426
}

void ChannelImpl::assert_in_worker() {
M
Megvii Engine Team 已提交
1427 1428 1429
    mgb_assert(
            get_worker_tid() == std::this_thread::get_id(),
            "this method can only be called in worker thread");
1430 1431
}

1432 1433 1434
void ChannelImpl::sample_on_device(CompNode device, bool force) {
    if (!force) {
        // rate-limit sampling unless explicitly forced
        thread_local int last_sample_id = 0;
        int sample_rate =
                Profiler::is_profiling() ? Profiler::get_option("sample_rate", 0) : 0;
        if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
            return;
        }
    }
    MGB_RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}

1446 1447 1448
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
1449
        erase_candidate(i);
1450 1451 1452
    }
}

1453 1454
void ChannelImpl::DynamicSublinear::unpin(
        const SmallVector<TensorInfo*>& vec, WorkerState& state) {
1455 1456
    for (auto i : vec) {
        i->unpin();
1457 1458 1459 1460 1461
        if (i->pinned == 0 &&
            i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
            i->cand_index == UINT_MAX) {
            insert_candidate(i);
        }
1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507
    }
}

// After `ptr` is recomputed it is resident again: remove its cost from the
// DSU component it belonged to and make it a standalone root carrying only
// its own compute time.
void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    // subtract this tensor's cost from the component's accumulated cost
    dsu_fa->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

// After `ptr` is evicted, merge its DSU node with every DROP-evicted
// neighbor (producer inputs and outputs), so connected evicted regions
// share one accumulated recomputation cost.
void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    // Sum the recomputation cost of every DROP-evicted neighbor among the
    // producer's inputs and outputs; each neighbor contributes at least its
    // own compute time, otherwise its DSU component's accumulated cost.
    double cost = 0;
    auto accumulate = [&](TensorInfo* neighbor) {
        double t = find_father(neighbor->dsu_ptr)->t;
        if (t < neighbor->compute_time) {
            t = neighbor->compute_time;
        }
        cost += t;
    };
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            accumulate(i);
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            accumulate(i);
        }
    }
    return cost;
}

M
Megvii Engine Team 已提交
1508 1509
// Pick the eviction candidate with the smallest MSPS score. When
// sqrt-sampling is enabled, only ~sqrt(|candidates|) entries are inspected,
// visited at random strides; otherwise all candidates are scanned.
// Returns nullptr when no candidate is eligible.
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
        bool enable_dtr_sqrt_sampling = false) {
    if (candidates.empty())
        return nullptr;

    double min_msps = -1;
    TensorInfo* best = nullptr;
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
        // sz = floor(sqrt(candidates.size()))
        while (sz * sz <= candidates.size())
            sz++;
        sz--;
    } else {
        sz = candidates.size();
    }

    size_t ti = rand() % sz;
    for (size_t vi = 0; vi < sz; vi++) {
        if (!enable_dtr_sqrt_sampling) {
            ti = vi;
        }
        auto i = candidates[ti];
        // only resident, regenerable, not-yet-evicted tensors are eligible
        if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
            double neighbor_cost = estimate_neighbor_cost(i);
            size_t begin_ptr =
                    reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(
                    begin_ptr, begin_ptr + i->ptr->blob()->size());
            double free_mem = side_info.first + side_info.second;
            double msps = i->eval_func(
                    neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
        if (enable_dtr_sqrt_sampling) {
            ti += rand() % sz;
            // fixed off-by-one: the original `ti > candidates.size()` let
            // ti == candidates.size() through, making candidates[ti] read
            // out of bounds on the next iteration
            if (ti >= candidates.size())
                break;
        }
    }
    return best;
}

M
Megvii Engine Team 已提交
1553 1554
void ChannelImpl::DynamicSublinear::merge(
        std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
    // Union the two DSU components, accumulating x's cost into y's root.
    auto&& root_x = find_father(x);
    auto&& root_y = find_father(y);
    if (root_x.get() == root_y.get()) {
        return;
    }
    root_y->t += root_x->t;
    root_x->parent = root_y;
}

M
Megvii Engine Team 已提交
1564 1565
std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
        std::shared_ptr<DsuNode>& x) {
    // Path-compressing find: reparent every node on the way to the root.
    if (x->is_root()) {
        return x;
    }
    auto&& root = find_father(x->parent);
    return x->parent = root;
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    // tensor to be inserted must be brand new
    mgb_assert(
            ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
            ptr->cand_index);
    ptr->cand_index = candidates.size();
    candidates.push_back(ptr);
    // lazily bind the DTR comp node to the first candidate's device
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    // close dtr will just clear candidates, so nothing to erase
    if (candidates.empty()) {
        ptr->cand_index = UINT_MAX;
        return;
    }
    // some tensors may be erased already, just skip them
    if (ptr->cand_index == UINT_MAX) {
        return;
    }
    // swap-and-pop keeps the candidate vector dense
    std::swap(candidates[ptr->cand_index], candidates.back());
    candidates[ptr->cand_index]->cand_index = ptr->cand_index;
    candidates.pop_back();
    ptr->cand_index = UINT_MAX;
}

// Stamp the tensor with the current estimated timestamp as its last-use time.
void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}