interpreter_impl.cpp 54.5 KB
Newer Older
1
#include "./interpreter_impl.h"
2

3 4
#include "range/v3/all.hpp"

5
#include "megbrain/common.h"
6 7
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
8 9
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
10
#include "megbrain/imperative/ops/utility.h"
11 12
#include "megbrain/imperative/utils/to_string.h"

13
#include "../blob_manager_impl.h"
14 15 16
#include "../event_pool.h"
#include "../op_trait.h"

17 18 19 20 21
using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

22
namespace {
M
Megvii Engine Team 已提交
23 24 25 26 27 28 29 30
auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
    SmallVector<uint64_t> tid;
    for (auto* ptinfo : tinfo) {
        tid.push_back(ptinfo->id);
    }
    return tid;
};
}  // namespace
31

32
namespace mgb {
M
Megvii Engine Team 已提交
33
using namespace profiler;
34 35
}

36 37 38 39 40
#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif
41 42 43 44 45 46 47

namespace mgb {

/**
 * USAGE
 *
 *   header:
48
 *     namespace mgb { void imperative_log_profile(const char* message); }
49 50 51 52 53
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
54
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    // Opens a named custom profiler interval; closed by the matching
    // imperative_log_profile_end() call with the same message.
    MGB_RECORD_EVENT(CustomEvent, std::string(message));
}

59
SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    // Closes the interval opened by imperative_log_profile_begin() with
    // the same message.
    MGB_RECORD_EVENT(CustomFinishEvent, std::string(message));
}

64
SYMBOL_EXPORT
// Records an effectively zero-duration profile marker by opening and
// immediately closing a custom interval for `message`.
void imperative_log_profile(const char* message) {
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

70 71 72 73
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message, const char* device) {
    // Device-aware variant: tags the custom event with the comp node
    // resolved from `device`, then records a device-timer event so the
    // interval can be correlated with device-side timing. The ordering
    // (CustomEvent first, RecordDeviceEvent second) mirrors the reversed
    // ordering in the matching *_end overload.
    auto cn = CompNode::load(device);
    MGB_RECORD_EVENT(CustomEvent, std::string(message), {}, cn);
    MGB_RECORD_EVENT(RecordDeviceEvent, EventPool::with_timer().alloc_shared(cn));
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message, const char* device) {
    // Device-aware variant: records the device-timer event before the
    // finish event — the mirror image of the *_begin overload above.
    auto cn = CompNode::load(device);
    MGB_RECORD_EVENT(RecordDeviceEvent, EventPool::with_timer().alloc_shared(cn));
    MGB_RECORD_EVENT(CustomFinishEvent, std::string(message), {}, cn);
}

M
Megvii Engine Team 已提交
86
}  // namespace mgb
87

88 89 90 91 92 93 94 95 96 97 98 99 100 101
// Id of the worker thread (set in on_async_queue_worker_thread_start);
// used by the thread-affinity assertions. Direct member access is legal
// here because the poisoning #defines appear further down the file.
std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

// Accessor for the channel-thread state; asserts the caller is actually
// on the channel (user) thread before handing the state out.
ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

// Accessor for the worker-thread state; asserts the caller is actually
// on the worker thread before handing the state out.
ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

102 103 104
// Runs once when the async work-queue worker thread starts: names the
// thread, records its id (so get_worker_state()'s affinity assert works),
// and installs a custom allocator routing allocations through the owning
// channel's eviction-aware path.
void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
    // NOTE(review): this lambda captures `this` by reference and escapes
    // via OpDef::set_allocator — assumes the queue outlives the allocator
    // registration; confirm against OpDef's lifetime rules.
    auto custom_allocator = [&](CompNode device, size_t size) {
        auto blob = Blob::make(device, size);
        // presumably may evict other tensors to satisfy the allocation
        // (see alloc_tensor_with_evict, defined elsewhere in this file)
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
    };
    OpDef::set_allocator(custom_allocator);
}

113
// The #defines below poison the identifiers m_channel_state / m_worker_state
// so that all code after this point must go through get_channel_state() /
// get_worker_state(), which assert the calling thread's affinity.
// Do not use the members directly.
114 115 116
#define m_channel_state
#define m_worker_state

117 118 119 120 121 122 123 124 125
// Factory: each channel owns its own worker queue and tensor-info pool.
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

// Process-wide singleton interpreter (thread-safe magic-static init).
Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

126
// Uploads a host tensor into the interpreter and returns an opaque handle.
// The actual upload is enqueued on the worker thread (see put_impl).
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    // Only build the profiler call-stack frame when profiling is on.
    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        auto& state = get_channel_state();
        guard.emplace("Put", &state.stack_manager);
    }
    auto info = put_impl(value, no_cache);
    return reinterpret_cast<Handle>(info);
}

// Allocates a TensorInfo for `value` and enqueues a Put command on the
// worker. Small tensors are additionally cached host-side so later value
// queries / static inference need not wait for the device.
TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    if (value.empty()) {
        // Normalize empty tensors to a contiguous layout.
        // NOTE(review): const_cast mutates the caller's const tensor —
        // relies on reset() only changing layout metadata here.
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto info = alloc();
    // Tensors with at most MAX_NDIM elements are considered "small" and
    // have their value kept on the host (e.g. shapes used in inference).
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {value.layout(), value.comp_node()});
    if (value.layout().total_nr_elems() <= size_threshold) {
        info->h_value = value;
        info->desc.value = value.proxy_to_default_cpu();
    }
    // Enqueue the upload; the stack dump is only attached when profiling.
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Put{info, value, no_cache},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Put{info, value, no_cache},
        });
    }
    // Fully synchronous mode: wait for the worker and surface any async
    // device error immediately.
    if (m_async_level == 0) {
        sync_impl();
        info->desc.comp_node.sync();
        auto err = info->desc.comp_node.check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info;
}

170
// Wraps an existing device tensor (plus optional host mirror) into a
// handle without any copy through the worker queue.
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return reinterpret_cast<Handle>(put_impl(data, hvalue));
}
M
Megvii Engine Team 已提交
175 176
// Device-tensor variant of put_impl: the tensor is produced immediately
// on the channel thread (no worker round-trip) since the data already
// lives on the device.
TensorInfo* ChannelImpl::put_impl(
        const DeviceTensorND& data, const HostTensorND& hvalue) {
    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        auto& state = get_channel_state();
        guard.emplace("Put", &state.stack_manager);
    }
    auto info = alloc();
    MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
    // Same "small tensor" host-value caching threshold as the host put.
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {data.layout(), data.comp_node()});
    if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
        info->desc.value = hvalue.proxy_to_default_cpu();
    }
    info->ptr = Tensor::make(data, hvalue);
    MGB_RECORD_EVENT(
            TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
            data.raw_ptr());
    info->status = TensorInfo::Produced;
    MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
    return info;
}

198
// Releases a handle. Unlike most entry points, calling del on an already
// closed channel is a silent no-op rather than an assertion failure.
void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    del_impl(handle);
}

void ChannelImpl::del_impl(Handle handle) {
207 208 209
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
210 211 212 213 214 215 216 217 218 219
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Del{info},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Del{info},
        });
    }
220 221
}

222
// Requests eviction of a tensor's device storage (it can be regenerated
// later from its producer op). Only acts when the enable_drop option is
// set; otherwise the call is a no-op.
void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(
                m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        // The drop itself happens on the worker thread.
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), Drop{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    Drop{info},
            });
        }
    }
}

244
// Executes `op` eagerly on the (default) CPU using the inputs' cached
// host values, bypassing the worker queue entirely. Used for cheap
// host-computable ops (dispatch mode DEFAULT_CPU). Appends the resulting
// handles to `*outputs`.
void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();

    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
    }

    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_RECORD_EVENT(ShapeInferEvent, validated);

    // Collect host-side views of every input; all inputs must agree on a
    // single comp node, which becomes the output comp node.
    SmallVector<DeviceTensorND> input_tensornds;
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(
                        info->ptr->get_value().proxy_to_default_cpu());
            } else {
                // We assign h_value before drop ptr, so h_value must be
                // available whenever ptr is not.
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    SmallVector<DeviceTensorND> output_tensornds;
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which need alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(
                HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    uint64_t op_id = Profiler::next_id();

    if (op->trait()->apply_on_device_tensornd) {
        OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
    } else {
        // proxy to apply_on_physical_tensor: wrap the host views as
        // Tensors, run, then copy results back into our host buffers
        SmallVector<TensorPtr> input_tensors;
        for (auto&& input_tensornd : input_tensornds) {
            input_tensors.push_back(Tensor::make(
                    input_tensornd, HostTensorND::make_proxy(input_tensornd)));
        }
        auto output_tensors = OpDef::apply_on_physical_tensor(
                *op, input_tensors, output_descs, validated);
        for (size_t i = 0; i < output_tensors.size(); ++i) {
            output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
        }
    }

    // Re-register each result through put_impl so the new tensors get
    // normal TensorInfo bookkeeping.
    SmallVector<TensorInfo*> output_infos;
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd =
                HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(reinterpret_cast<Handle>(info));
    }
    auto op_info_getter = [op] {
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value] : props) {
            op_info[key] = value;
        }
        return op_info;
    };
    // NOTE(review): guard is only engaged while profiling; assumes
    // MGB_RECORD_EVENT does not evaluate its arguments otherwise.
    MGB_RECORD_EVENT(
            OpDispatchEvent, op_id, guard.value().name(), op_info_getter,
            tinfo_to_tid(input_infos), tinfo_to_tid(output_infos),
            state.stack_manager.dump());
}
332

333
void ChannelImpl::dispatch_kernel(
M
Megvii Engine Team 已提交
334
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
335 336
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
337
    auto& state = get_channel_state();
338 339
    auto& options = state.options;

340 341 342 343
    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
    }
344

M
Megvii Engine Team 已提交
345 346
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
347
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
348

349 350 351 352
    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());

    outputs->reserve(output_descs.size());
353 354
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
355
        auto info = alloc();
356
        init(info, std::move(desc));
357 358
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
359 360
            info->h_value = HostTensorND::make_proxy(info->desc.value)
                                    .proxy_to_comp_node(info->desc.comp_node);
361
        }
362
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
363
        outputs->push_back(reinterpret_cast<Handle>(info));
364
    }
365 366 367
    ApplyOp cmd{
            Profiler::next_id(), std::move(op), std::move(input_infos),
            std::move(output_infos), validated};
368
    if (Profiler::is_profiling()) {
369 370 371 372 373 374 375 376
        auto op_info_getter = [op = cmd.op] {
            std::unordered_map<std::string, std::string> op_info;
            auto props = OpDef::props(*op);
            for (auto&& [key, value] : props) {
                op_info[key] = value;
            }
            return op_info;
        };
377
        MGB_RECORD_EVENT(
378 379 380
                OpDispatchEvent, cmd.id, guard.value().name(), op_info_getter,
                tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs),
                state.stack_manager.dump());
381
        m_worker.add_task(
382
                {Profiler::next_id(), std::move(cmd),
383 384 385 386
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
387
                std::move(cmd),
388 389
        });
    }
390
    if (!validated && options.async_level == 1) {
391
        sync_impl();
392
    } else if (options.async_level == 0) {
393
        sync_impl();
394
        // check device error
395
        for (auto&& oup : *outputs) {
396 397
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
398 399
            auto err = info->ptr->comp_node().check_async_error();
            mgb_assert(!err, "%s", err->what());
400
        }
401
    }
402 403 404
}

// Public op entry point. Fast-path: a full GetVarShape on an input whose
// shape is already statically known is answered directly from metadata,
// without touching the worker. Everything else goes to apply_op_impl.
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    // fix: the original indexed inputs[0] unconditionally, which is UB
    // for ops dispatched with zero inputs; guard the fast path instead.
    if (!inputs.empty() && op->same_type<GetVarShape>()) {
        auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
        if (input->desc.layout.ndim) {
            size_t ndim = input->desc.layout.ndim;
            auto& gvs = op->cast_final_safe<GetVarShape>();
            // axis == MEGDNN_MAX_NDIM means "whole shape requested"
            if (gvs.axis == MEGDNN_MAX_NDIM) {
                HostTensorND shape_tensor{
                        input->desc.comp_node, {ndim}, dtype::Int32()};
                DeviceTensorND shape_tensor_device =
                        shape_tensor.proxy_to_default_cpu();
                cg::copy_shape_to_tensor_value(
                        shape_tensor_device, input->desc.layout);
                return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
            }
        }
    }
    return apply_op_impl(std::move(op), inputs);
}

SmallVector<Handle> ChannelImpl::apply_op_impl(
M
Megvii Engine Team 已提交
423
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
424
    auto& state = get_channel_state();
425
    for (auto i : inputs) {
M
Megvii Engine Team 已提交
426 427 428
        mgb_assert(
                m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
                i);
429 430 431 432
    }
    SmallVector<TensorInfo*> input_infos;
    SmallVector<LogicalTensorDesc> input_descs;
    {
433
        MGB_LOCK_GUARD(m_info_spin);
434 435
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
M
Megvii Engine Team 已提交
436 437 438
            mgb_assert(
                    !info->invalid,
                    "an input tensor is unusable due to previous error");
439 440 441 442 443 444
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
445
    DispatchMode dispatch_mode = state.options.enable_host_compute
M
Megvii Engine Team 已提交
446 447
                                       ? OpDef::decide_dispatch_mode(*op, input_descs)
                                       : DispatchMode::KERNEL;
448
    switch (dispatch_mode) {
449 450 451 452 453 454 455 456 457
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
458 459 460
    return outputs;
}

461
// Blocks until the tensor's host value is available and returns it.
HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // donnot use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

473
// Returns the tensor's shape. Fast path: answer from the cached
// descriptor when the shape is already known; otherwise block on the
// worker until the shape is produced.
TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

488
// Returns the tensor's dtype from the cached descriptor (never blocks —
// dtype is known from inference at handle creation).
DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

501
// Returns the tensor's comp node from the cached descriptor (never
// blocks — the device is fixed when the handle is created).
CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

514
// Blocks until the tensor's device value is produced and returns it.
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

// Public sync: drains the worker queue (and rethrows worker exceptions).
void ChannelImpl::sync() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    sync_impl();
}

void ChannelImpl::sync_impl() {
531 532 533 534 535 536
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

// Shuts the channel down: deletes all still-live handles, drains the
// worker, then marks the channel closed. Idempotent.
void ChannelImpl::close() {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    // Snapshot first — del_impl mutates m_valid_handle while we iterate.
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
    for (auto* handle : valid_handles) {
        del_impl(handle);
    }
    mgb_assert(m_valid_handle.empty());
    // Reports how many handles were still alive when close() was called.
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
    sync_impl();
    m_closed = true;
}

551
// Reads a named option from the channel-thread option store.
size_t ChannelImpl::get_option(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    return state.options.get_option(name);
}

558
// Sets a named option on the channel-thread state and mirrors the change
// to the worker via a SetOption command (worker state is updated there).
void ChannelImpl::set_option(std::string name, size_t value) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.options.set_option(name, value);
    // FIXME
    // Enabling DTR installs an eviction-aware allocator globally.
    // NOTE(review): the lambda captures `this` by reference and escapes
    // into BlobManager — dangles if this channel is destroyed first.
    if (name == "enable_dtr_auto_drop" && value) {
        auto custom_allocator = [&](CompNode device, size_t size) {
            auto blob = Blob::make(device, size);
            alloc_tensor_with_evict(blob.get());
            return blob->storage();
        };
        BlobManager::inst()->set_allocator(custom_allocator);
    }
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), SetOption{name, value},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                SetOption{name, value},
        });
    }
}

584 585 586 587 588 589
// Empties the DTR eviction-candidate set.
void ChannelImpl::clear_candidates() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_dtr.candidates.clear();
}

590
// Allocates a fresh TensorInfo from the pool (pool access is guarded by
// its own spinlock) and assigns it a globally unique id. Under profiling
// the info also gets a human-readable name derived from the call stack.
TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this] {
        MGB_LOCK_GUARD(m_pool_spin);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        size_t tensor_id = state.stack_manager.current()->next_id("tensor");
        info->name =
                state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
    }
    return info;
}

605
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
M
Megvii Engine Team 已提交
606
    m_valid_handle.insert(reinterpret_cast<Handle>(info));
607
    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
608
    info->status = TensorInfo::Allocated;
609
    info->desc = std::move(desc);
610 611
}

M
Megvii Engine Team 已提交
612
// Drops a tensor's device storage so it can be regenerated later from
// its producer. `user` marks a user-initiated drop, which warns (instead
// of silently ignoring) when the producer chain is gone. No-op if the
// tensor is already evicted.
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn(
                    "the input that produced tensor %p has been deleted, this drop "
                    "operation will be ignored",
                    ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

630
// Frees a tensor. Under DTR auto-drop, deletion is deferred while other
// tensors still reference this one (ref_cnt > 0): the tensor is dropped
// instead, and recursive_free runs once the last reference is gone.
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

void ChannelImpl::recursive_free(TensorInfo* ptr) {
648
    MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
649
    SmallVector<TensorInfo*> inps;
650 651 652 653 654 655 656 657 658 659 660 661 662
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
663
    MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
664 665 666
}

void ChannelImpl::real_free(TensorInfo* ptr) {
667 668
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
669 670 671 672
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
673 674
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
675
        MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
676
    }
677
    MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
678
    ptr->status = TensorInfo::Deleted;
679
    MGB_LOCK_GUARD(m_pool_spin);
680 681 682
    m_pool.free(ptr);
}

683
// The worker queue keeps a back-pointer to its owning channel.
ChannelImpl::ChannelImpl() : m_worker(this) {}
684

685 686 687
// Destruction flushes pending work and invalidates all handles (close()
// is idempotent, so an explicit earlier close is fine).
ChannelImpl::~ChannelImpl() {
    close();
}
688

689
// Worker-side: installs a computed tensor into its TensorInfo, verifies
// it against the inferred shape, updates DTR bookkeeping, and wakes any
// channel-thread waiters (wait_tensor).
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
            ptr->raw_ptr_not_for_readwrite());
    // update tensor desc for static infer
    if (dest->desc.layout.ndim) {
        mgb_assert(
                dest->desc.layout.eq_shape(ptr->layout()),
                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
                ptr->layout().to_string().c_str());
    }
    // in order to avoid performance impact,
    // memory forwarding is disabled when DTR is enabled
    if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
        ptr->to_contiguous_inplace();
    }
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    // Pinned tensors and tensors below the size threshold never become
    // eviction candidates.
    if (dest->pinned == 0 &&
        dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

720
// Drops the TensorInfo's device storage (keeps the info itself alive)
// and removes it from DTR candidacy.
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
}

730
// Recomputes a dropped tensor by replaying its producer op: pushes the
// producer as an ApplyOp frame onto the apply stack (tagged "dtr") and
// flushes it, unless a flush is already in progress higher on the stack.
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        auto&& path = dest->producer;
        m_apply_stack.push(
                {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
                 "dtr"});
        if (!m_applying)
            flush_apply_stack();
    }
}

741
// Execute one ApplyOp command on the worker thread: gather input TensorPtrs,
// optionally auto-evict under DTR, run the op on physical tensors (recursing
// through make_forward_graph when the op is compound), then publish outputs
// and update DTR bookkeeping. `reason` is forwarded to profiler events.
void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    bool profiling_device =
            Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
    uint64_t apply_id = cmd.id;
    SmallVector<TensorPtr> inputs;
    inputs.reserve(cmd.inputs.size());
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
        // refcnt ++, owners: [i->ptr, tensor_inputs]
        // tensor_inputs.push_back(i->ptr);
        inputs.push_back(i->ptr);
    }
    if (state.options.enable_dtr_auto_drop &&
        state.options.dtr_eviction_threshold > 0) {
        auto_evict(0);
    }
    // Recursive executor: compound ops (with make_forward_graph) are expanded
    // and applied sub-op by sub-op via `self`; leaf ops go straight to
    // OpDef::apply_on_physical_tensor after layout constraints are enforced.
    auto apply_on_physical_tensor =
            [&](auto&& self, const OpDef& def, SmallVector<TensorPtr>&& inputs,
                SmallVector<LogicalTensorDesc>& output_descs,
                const bool& validated) -> SmallVector<TensorPtr> {
        if (def.trait()->make_forward_graph) {
            auto apply_functor = [&](std::shared_ptr<OpDef> op,
                                     SmallVector<TensorPtr> inputs,
                                     size_t nr_outputs) -> SmallVector<TensorPtr> {
                auto opname = op->trait()->make_name(*op);
                imperative_log_profile_begin(opname.c_str());
                auto outputs = self(self, *op, std::move(inputs), output_descs, false);
                imperative_log_profile_end(opname.c_str());
                return outputs;
            };
            auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
            // apply recursivily
            SmallVector<LogicalTensorDesc> input_descs;
            for (auto&& input : inputs) {
                input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
            }
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
            auto outputs = forward_graph.apply<TensorPtr>(
                    inputs, apply_functor, const_functor);
            return outputs;
        }
        // Check Input Layout
        // Get the input layout constraints, and if the constraint is not satisfied
        // inplace update the layout and blob to make the tensor contiguous
        auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
        for (size_t idx = 0; idx < inputs.size(); ++idx) {
            auto&& layout_checker = constraints[idx];
            if (layout_checker) {
                inputs[idx]->to_contiguous_inplace(layout_checker);
            }
        }
        auto outputs = OpDef::apply_on_physical_tensor(
                def, std::move(inputs), output_descs, validated);
        for (auto& o : outputs) {
            o->set_ready_event(
                    record_event(o->comp_node(), def.same_type<imperative::Barrier>()));
        }
        return outputs;
    };
    MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
            }
        }
    }
    for (auto* input : cmd.inputs) {
        auto input_id = input->id;
        MGB_RECORD_EVENT(OpInputEvent, input_id);
        MGB_RECORD_EVENT(TensorUsageEvent, input_id);
        MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
    }
    // Before wait
    // TODO: split operator wait and execute so that OpWait could be corrected recorded.
    // Before execute
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
        MGB_RECORD_EVENT_IF(
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
    }
    // Apply op
    SmallVector<LogicalTensorDesc> output_descs;
    bool validated = cmd.validated;
    // Under DTR the outputs may be recomputed later, so previously inferred
    // descs cannot be trusted; force re-validation instead.
    if (!state.options.enable_dtr_auto_drop) {
        for (auto i : cmd.outputs) {
            output_descs.push_back(i->desc);
        }
    } else {
        validated = false;
    }
    // Here std::move is REQUIRED for removing duplicated references.
    auto outputs = apply_on_physical_tensor(
            apply_on_physical_tensor, *cmd.op, *std::move(inputs), output_descs,
            validated);
    // After execute
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT_IF(
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
        MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
    }
    // End profiling operator
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto output = cmd.outputs[i];
        if (mgb_unlikely(output == nullptr)) {
            // output was dead on arrival (already deleted by the user)
            MGB_RECORD_EVENT(OpOutputEvent, 0);
            MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
        } else if (mgb_unlikely(output->ptr != nullptr)) {
            // output already has a value (e.g. produced by a recompute race)
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
        } else {
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            produce_tensor(output, outputs[i]);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
            sample_on_device(output->desc.comp_node, false);
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        // Rough cost model: compute time is approximated by bytes touched.
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
        for (auto i : outputs) {
            estimate_compute_time += i->blob()->size();
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
        m_dtr.unpin(cmd.inputs, state);
    }
    MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
    // End profiling operator
}
889

890 891
// Drain m_apply_stack: for each pending ApplyOp, regenerate any dropped
// inputs (which pushes new ApplyOps on the stack) before executing it.
// `m_applying` guards against re-entrant flushing from regenerate().
void ChannelImpl::flush_apply_stack() {
    m_applying = true;
    auto& state = get_worker_state();
    while (!m_apply_stack.empty()) {
        auto& [cmd, idx, recomp, reason] =
                m_apply_stack.top();  // cmd.inputs[0~idx-1] is in memory
        if (idx == 0) {
            if (state.options.enable_dtr_auto_drop) {
                // pin inputs so eviction cannot drop them mid-recompute
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
                MGB_RECORD_EVENT(
                        TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
            }
        }
        bool regen = false;
        for (size_t i = idx; i < cmd.inputs.size(); i++) {
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                idx = i + 1;
                regenerate(p);  // add ApplyOp to the stack
                regen = true;
                break;
            }
        }
        if (regen)
            continue;
        // the required input tensors are already in memory
        // NOTE: copy out before pop() — the structured bindings above
        // reference the stack's top element, which pop() destroys.
        auto [cmd_backup, recomp_backup, reason_backup] =
                std::make_tuple(cmd, recomp, reason);
        m_apply_stack.pop();
        do_apply_op(cmd_backup, reason_backup);
        if (recomp_backup) {
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, recomp_backup->id,
                    TensorCommandKind::ReGen);
            for (auto o : cmd_backup.outputs) {
                if (o) {
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
    }
    m_applying = false;
}

939
// Evict DTR candidates until memory usage drops below the configured
// threshold, or until `force_num` tensors have been force-evicted.
// Returns true if at least one tensor's storage was actually released.
bool ChannelImpl::auto_evict(size_t force_num) {
    auto& state = get_worker_state();
    if (!m_dtr.comp_node.valid()) {
        return false;
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
    // Fix: was declared `size_t flag = false;` although it is used and
    // returned as a boolean; declare it with the correct type.
    bool flag = false;
    while ((state.options.dtr_eviction_threshold > 0 &&
            current_memory > state.options.dtr_eviction_threshold) ||
           force_num > 0) {
        MGB_RECORD_EVENT(AutoEvictEvent);
        sample_on_device(m_dtr.comp_node, false);
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
        if (!best) {
            // nothing evictable left
            MGB_RECORD_EVENT(AutoEvictFinishEvent);
            break;
        }
        // Only count the memory as reclaimed when this channel holds the sole
        // reference to both the tensor and its underlying blob.
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
            if (force_num > 0) {
                force_num--;
            }
            flag = true;
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
        }
        sample_on_device(m_dtr.comp_node, false);
        MGB_RECORD_EVENT(AutoEvictFinishEvent);
    }
    return flag;
}

973 974
// Detach every ComputePath that consumes `dest`, so `dest` can be freed
// without leaving dangling recompute paths. Outputs that might still be
// needed are regenerated first.
void ChannelImpl::detach_users(TensorInfo* dest) {
    // Copy the user list: detach_producer() below mutates dest->users
    // while we iterate.
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user : users) {
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output : outputs) {
            // When a `ComputePath` is detach from it's input,
            // there is no need to reserve it,
            // so we detach all output of this path
            // to decrease it's `ref_cnt` to zero.
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
            for (auto* input : inputs) {
                input->ref_cnt--;
            }
        }
        // now user is dead
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}

997 998 999 1000
bool ChannelImpl::check_available() {
    // The channel is usable until close() has been called.
    return m_closed == false;
}

1001 1002 1003 1004 1005
// Block the calling (channel) thread until `info` satisfies `prop`:
// HostValue requires the value to be fetched to host; other props only
// require the device tensor to exist. Returns the tensor's ptr.
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool require_host = prop == TensorProp::HostValue;
    auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
    bool wait_host = false;
    if (require_host && !host_available()) {
        // avoid dead lock
        lock.unlock();
        // enqueue a GetValue command so the worker fetches the host value
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), GetValue{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    GetValue{info},
            });
        }
        lock.lock();
        wait_host = true;
    }
    m_cv.wait(lock, [&]() {
        // rethrows (as AsyncError) if the worker recorded an exception
        check_worker_exc_unsafe();
        return require_host ? host_available() : static_cast<bool>(info->ptr);
    });
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
    m_waitee = nullptr;
    if (wait_host) {
        // surface async device errors raised while fetching the host value
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info->ptr;
}

void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    // Wake the waiting channel thread, but only when `info` is the tensor
    // currently being waited on. Caller must hold m_mutex ("_unsafe").
    if (info != m_waitee) {
        return;
    }
    MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
    m_cv.notify_all();
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
M
Megvii Engine Team 已提交
1048
    for (auto* handle : m_valid_handle) {
1049 1050
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
1051
    }
1052
    return valid_tensors;
1053 1054
}

1055
// Allocate storage for blob `x`, using DTR eviction to make room on
// failure, and falling back to a full defragmentation as a last resort.
// Eviction is only attempted on the worker thread.
void ChannelImpl::alloc_tensor_with_evict(OwnedBlob* x) {
    bool in_worker = (get_worker_tid() == std::this_thread::get_id());
    // Evict until a free block of at least `size` bytes exists (best effort).
    auto reserve_size = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            bool evict_suc = auto_evict(1);
            if (!evict_suc)
                return false;
        }
        return true;
    };
    // Silence allocator warnings during the retry loop; restored at the end.
    auto pre_level = set_log_level(LogLevel::NO_LOG);
    if (in_worker) {
        reserve_size(x->size());
    }
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
    MGB_CATCH(MemAllocError&, {
        bool suc = false;
        if (in_worker) {
            // evict one tensor at a time and retry until success or
            // nothing is left to evict
            while (!suc) {
                if (!auto_evict(1)) {
                    break;
                }
                MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
                MGB_CATCH(MemAllocError&, { continue; });
                suc = true;
            }
        }
        if (!suc) {
            set_log_level(pre_level);
            mgb_log_warn(
                    "reallocating all cuda memory to alleviate fragmentation, the "
                    "performance may be affected");
            set_log_level(LogLevel::NO_LOG);
            imperative_log_profile_begin("defrag");
            BlobManager::inst()->defrag(x->comp_node());
            imperative_log_profile_end("defrag");
            // final attempt; lets MemAllocError propagate if it still fails
            BlobManager::inst()->alloc_direct(x, x->size());
        }
    });
    set_log_level(pre_level);
}

1100
// Worker-thread dispatcher: executes a single queued Command. Each command
// kind (Put/ApplyOp/Del/GetValue/Drop/SetOption/Start-StopProfile/
// Push-PopScope) is handled by `cmd_visitor`; when catch_worker_execption
// is set, exceptions are recorded and the waiting thread is woken instead
// of crashing the worker.
void ChannelImpl::process_one_task(Command& icmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    // TODO: remove std::visit for support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, Put>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
                                      : Tensor::make(cmd.value);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            produce_tensor(cmd.dest, std::move(value));
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, ApplyOp>) {
            // poisoned inputs poison all outputs; skip execution entirely
            for (auto& i : cmd.inputs) {
                if (mgb_unlikely(i->invalid)) {
                    MGB_LOCK_GUARD(m_mutex);
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
                    }
                    return;
                }
            }
            if (state.options.enable_dtr_auto_drop) {
                m_apply_stack.push({cmd, 0, nullptr, "cmd"});
                flush_apply_stack();
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
                    output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                }
            } else {
                do_apply_op(cmd, "cmd");
            }
            if (state.options.enable_drop && state.options.record_computing_path) {
                auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                    auto& input = std::get<0>(tuple2);
                    auto& output = std::get<1>(tuple2);
                    if (!input->ptr || !output->ptr) {
                        return false;
                    }
                    return input->ptr->blob()->storage() ==
                           output->ptr->blob()->storage();
                };
                // FIXME: do not use opname as identifier
                auto get_name = [](const OpDef& opdef) {
                    if (auto attr = opdef.try_cast_final<OprAttr>()) {
                        return attr->type.c_str();
                    }
                    return opdef.dyn_typeinfo()->name;
                };

                auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
                    return info->desc.comp_node != comp_node;
                };

                bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                bool inplace =
                        any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                // Only record a recompute path for ops that are safe to replay:
                // not in-place, single device, and not blacklisted.
                if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                    TensorInfo::ComputePath::make(
                            cmd.id, cmd.op, cmd.inputs, cmd.outputs);
                    size_t detach_cnt = 0;
                    if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
                        cmd.outputs.size() == 6) {
                        cmd.outputs[0]->detach_producer();  // detach running_mean
                        cmd.outputs[1]->detach_producer();  // detach running_var
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= 2;
                        }
                    }
                    for (auto output : cmd.outputs) {
                        // tiny outputs are cheaper to keep than to recompute
                        if (output->producer &&
                            !output->size_exceeds_thd(
                                    state.options.dtr_evictee_minimum_size)) {
                            output->detach_producer();
                            detach_cnt++;
                        }
                    }
                    for (auto input : cmd.inputs) {
                        input->ref_cnt -= detach_cnt;
                    }
                }
            }
        } else if constexpr (std::is_same_v<T, Del>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
            CompNode device = cmd.dest->desc.comp_node;
            uint64_t tensor_id = cmd.dest->id;
            free(cmd.dest);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
            sample_on_device(device, false);
        } else if constexpr (std::is_same_v<T, GetValue>) {
            if (cmd.dest->invalid)
                return;
            imperative_log_profile_begin("GetValue");
            if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                regenerate(cmd.dest);
            }
            cmd.dest->ptr->fetch_value();
            MGB_LOCK_GUARD(m_mutex);
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
            do_drop(cmd.dest, true);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
        } else if constexpr (std::is_same_v<T, SetOption>) {
            options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            MGB_RECORD_EVENT(StartProfileEvent);
            CompNode::sync_all();
            for (auto* info : cmd.capture_tensors) {
                MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                if (info->status == TensorInfo::Produced) {
                    // TODO: handle drop
                    MGB_RECORD_EVENT(
                            TensorProduceEvent, info->id, info->desc.layout,
                            info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                }
            }
            CompNode::foreach ([&](CompNode device) {
                sample_on_device(device, true);
                MGB_RECORD_EVENT_IF(
                        (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                        Timer::record_device(device));
            });
            MGB_RECORD_EVENT(StartProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
            MGB_RECORD_EVENT(StopProfileEvent);
            for (auto* info : cmd.escape_tensors) {
                bool has_value = info->status == TensorInfo::Produced;
                if (has_value) {
                    MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
                }
                MGB_RECORD_EVENT(TensorEraseEvent, info->id);
            }
            CompNode::foreach (
                    [&](CompNode device) { sample_on_device(device, true); });
            MGB_RECORD_EVENT(StopProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, PushScope>) {
            MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
            MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
        } else {
            // unhandled command kind: fail at compile time
            static_assert(!std::is_same_v<T, T>);
        }
    };
    std::visit(
            [&](const auto& cmd) {
                using T = std::decay_t<decltype(cmd)>;
                if (!options.catch_worker_execption) {
                    cmd_visitor(cmd);
                    return;
                }
                try {
                    cmd_visitor(cmd);
                } catch (...) {
                    MGB_LOCK_GUARD(m_mutex);
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        for (auto oup : cmd.outputs) {
                            oup->invalid = true;
                        }
                    } else if constexpr (std::is_same_v<T, Put>) {
                        cmd.dest->invalid = true;
                    }
                    // stash the exception; wait_tensor() rethrows it as
                    // AsyncError on the channel thread
                    m_worker_exc = std::current_exception();
                    MGB_RECORD_EVENT(WorkerExceptionEvent);
                    if (m_waitee) {
                        notify_tensor_unsafe(m_waitee);
                    }
                }
            },
            icmd.data);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
1293 1294
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
1295 1296
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
1297 1298 1299 1300 1301
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
1302 1303
    }
}
1304

1305
void ChannelImpl::start_profile() {
1306
    MGB_LOCK_GUARD(m_spin);
1307
    mgb_assert(check_available(), "Channel already closed");
1308 1309
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
1310 1311 1312 1313 1314 1315 1316 1317 1318 1319
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StartProfile{std::move(capture_tensors)},
            });
        }
1320
    }
1321 1322
}

1323
void ChannelImpl::stop_profile() {
1324
    MGB_LOCK_GUARD(m_spin);
1325
    mgb_assert(check_available(), "Channel already closed");
1326 1327
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
1328 1329 1330 1331 1332 1333 1334 1335 1336 1337
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StopProfile{std::move(escape_tensors)},
            });
        }
1338
    }
1339 1340 1341
}

void ChannelImpl::push_scope(std::string name) {
1342
    MGB_LOCK_GUARD(m_spin);
1343
    mgb_assert(check_available(), "Channel already closed");
1344
    auto& state = get_channel_state();
1345
    state.stack_manager.enter(name);
1346
    MGB_RECORD_EVENT(ScopeEvent, name);
1347 1348 1349 1350 1351 1352 1353 1354 1355 1356
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PushScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PushScope{name},
        });
    }
1357 1358 1359
}

void ChannelImpl::pop_scope(std::string name) {
1360
    MGB_LOCK_GUARD(m_spin);
1361
    mgb_assert(check_available(), "Channel already closed");
1362
    auto& state = get_channel_state();
1363
    state.stack_manager.exit(name);
1364
    MGB_RECORD_EVENT(ScopeFinishEvent, name);
1365 1366 1367 1368 1369 1370 1371 1372 1373 1374
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PopScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PopScope{name},
        });
    }
1375 1376
}

1377
void ChannelImpl::assert_in_channel() {
M
Megvii Engine Team 已提交
1378 1379 1380
    mgb_assert(
            get_worker_tid() != std::this_thread::get_id(),
            "this method cannot be called in worker thread");
1381 1382 1383
}

void ChannelImpl::assert_in_worker() {
M
Megvii Engine Team 已提交
1384 1385 1386
    mgb_assert(
            get_worker_tid() == std::this_thread::get_id(),
            "this method can only be called in worker thread");
1387 1388
}

1389
void ChannelImpl::sample_on_device(CompNode device, bool force) {
    // Record a memory-usage sample for `device`. Unforced samples are
    // rate-limited by the "sample_rate" profiler option (0 disables them).
    if (!Profiler::is_profiling()) {
        return;
    }
    if (!force) {
        thread_local int last_sample_id = 0;
        int sample_rate = Profiler::get_option("sample_rate", 0);
        if (sample_rate == 0) {
            return;
        }
        if ((++last_sample_id) % sample_rate != 0) {
            return;
        }
    }
    MGB_RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}

1405 1406 1407
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
1408
        erase_candidate(i);
1409 1410 1411
    }
}

1412 1413
void ChannelImpl::DynamicSublinear::unpin(
        const SmallVector<TensorInfo*>& vec, WorkerState& state) {
1414 1415
    for (auto i : vec) {
        i->unpin();
1416 1417 1418 1419 1420
        if (i->pinned == 0 &&
            i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
            i->cand_index == UINT_MAX) {
            insert_candidate(i);
        }
1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466
    }
}

void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    // The tensor is materialized again: subtract its compute time from its
    // DSU component's total and split it back out as a standalone root.
    find_father(ptr->dsu_ptr)->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    // After dropping `ptr`, merge its DSU component with every neighboring
    // dropped tensor so the aggregated recompute cost stays accurate.
    for (auto* in : ptr->producer->inputs) {
        if (in->evict_type == EvictType::DROP) {
            merge(in->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto* out : ptr->producer->outputs) {
        if (out && out->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, out->dsu_ptr);
        }
    }
}

/**
 * Estimate the recompute cost of ptr's dropped neighborhood: for every
 * DROP-evicted input/output of ptr's producer, add the accumulated compute
 * time of its DSU component (at least the neighbor's own compute time).
 *
 * @param ptr tensor whose neighborhood is being priced
 * @return summed estimated recompute time of the dropped neighbors
 */
double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    double cost = 0;
    // Shared accumulator for both loops: the DSU root's pooled time may lag
    // behind a freshly updated compute_time, so take whichever is larger.
    auto add_neighbor = [&](TensorInfo* i) {
        if (i && i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    };
    for (auto* i : ptr->producer->inputs) {
        add_neighbor(i);
    }
    for (auto* i : ptr->producer->outputs) {
        add_neighbor(i);
    }
    return cost;
}

M
Megvii Engine Team 已提交
1467 1468
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
        bool enable_dtr_sqrt_sampling = false) {
1469 1470 1471
    if (candidates.empty())
        return nullptr;

1472 1473
    double min_msps = -1;
    TensorInfo* best = nullptr;
1474 1475
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
M
Megvii Engine Team 已提交
1476 1477
        while (sz * sz <= candidates.size())
            sz++;
1478
        sz--;
1479 1480 1481
    } else {
        sz = candidates.size();
    }
1482 1483 1484 1485 1486 1487 1488

    size_t ti = rand() % sz;
    for (size_t vi = 0; vi < sz; vi++) {
        if (!enable_dtr_sqrt_sampling) {
            ti = vi;
        }
        auto i = candidates[ti];
1489
        if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
1490
            double neighbor_cost = estimate_neighbor_cost(i);
M
Megvii Engine Team 已提交
1491 1492 1493 1494
            size_t begin_ptr =
                    reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(
                    begin_ptr, begin_ptr + i->ptr->blob()->size());
1495
            double free_mem = side_info.first + side_info.second;
M
Megvii Engine Team 已提交
1496 1497
            double msps = i->eval_func(
                    neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
1498 1499 1500 1501 1502
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
1503 1504 1505 1506 1507
        if (enable_dtr_sqrt_sampling) {
            ti += rand() % sz;
            if (ti > candidates.size())
                break;
        }
1508 1509 1510 1511
    }
    return best;
}

M
Megvii Engine Team 已提交
1512 1513
void ChannelImpl::DynamicSublinear::merge(
        std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
1514 1515 1516 1517 1518 1519 1520 1521 1522
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

M
Megvii Engine Team 已提交
1523 1524
std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
        std::shared_ptr<DsuNode>& x) {
1525 1526 1527 1528 1529 1530 1531 1532 1533
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    // Only a brand-new entry may be inserted; a tensor that is already in
    // the list would corrupt the swap-and-pop bookkeeping in erase_candidate.
    mgb_assert(
            ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
            ptr->cand_index);
    // Remember the comp_node of the first candidate ever seen.
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
    ptr->cand_index = candidates.size();
    candidates.push_back(ptr);
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    // Turning DTR off clears the whole candidate list at once; in that case
    // only the stale index needs resetting.
    if (candidates.empty()) {
        ptr->cand_index = UINT_MAX;
        return;
    }
    // Tensors that were erased earlier carry UINT_MAX already — skip them.
    if (ptr->cand_index == UINT_MAX) {
        return;
    }
    // O(1) removal: swap the victim with the last element, patch the moved
    // element's index, then pop.
    size_t idx = ptr->cand_index;
    std::swap(candidates[idx], candidates.back());
    candidates[idx]->cand_index = idx;
    candidates.pop_back();
    ptr->cand_index = UINT_MAX;
}

// Stamp the tensor with the current estimated timestamp; eval_func reads
// last_used_time when scoring eviction candidates.
void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}