interpreter_impl.cpp 54.2 KB
Newer Older
M
Megvii Engine Team 已提交
1
/**
2
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
M
Megvii Engine Team 已提交
3 4
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
M
Megvii Engine Team 已提交
6 7 8 9 10 11
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12
#include "./interpreter_impl.h"
13

14 15
#include "range/v3/all.hpp"

16
#include "megbrain/common.h"
17 18
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
19 20
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
21
#include "megbrain/imperative/ops/utility.h"
22
#include "megbrain/imperative/utils/stats.h"
23 24
#include "megbrain/imperative/utils/to_string.h"

25
#include "../blob_manager_impl.h"
26 27 28
#include "../event_pool.h"
#include "../op_trait.h"

29 30 31 32 33
using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

34
namespace {
M
Megvii Engine Team 已提交
35 36 37 38 39 40 41 42
auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
    SmallVector<uint64_t> tid;
    for (auto* ptinfo : tinfo) {
        tid.push_back(ptinfo->id);
    }
    return tid;
};
}  // namespace
43

44
namespace mgb {
M
Megvii Engine Team 已提交
45
using namespace profiler;
46 47
}

48 49 50 51 52
// Mark a symbol as part of the public ABI so external tooling (e.g. profiling
// hooks resolved via dlsym) can locate it even in default-hidden builds.
#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif
53 54 55 56 57 58 59

namespace mgb {

/**
 * Exported profiling hooks.
 *
 * USAGE
 *
 *   header:
 *     namespace mgb { void imperative_log_profile(const char* message); }
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/

SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    // Open a custom profiling range on the host timeline.
    MGB_RECORD_EVENT(CustomEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    // Close the range opened by the matching imperative_log_profile_begin().
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile(const char* message) {
    // Point marker: a begin immediately followed by an end.
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message, const char* device) {
    // Device-aware variant: additionally records a device event so the range
    // can be correlated with the device timeline.
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message, const char* device) {
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
}

}  // namespace mgb
99

100 101 102 103 104 105 106 107 108 109 110 111 112 113
// Thread-identity helpers. Channel state must only be touched from the user
// thread and worker state only from the worker thread; the accessors assert
// the caller is on the right one.
std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

114 115 116
void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
117
    auto custom_allocator = [&](CompNode device, size_t size) {
118 119 120
        auto blob = Blob::make(device, size);
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
121 122 123
    };
    OpDef::set_allocator(custom_allocator);
    BlobManager::inst()->set_allocator(custom_allocator);
124 125
}

126
// Poison direct access to the state members from here on: all later code must
// go through get_channel_state() / get_worker_state(), which assert that the
// correct thread is calling.
#define m_channel_state
#define m_worker_state

130 131 132 133 134 135 136 137 138
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    // Each channel owns its own worker queue and handle namespace.
    return std::make_unique<ChannelImpl>();
}

Interpreter& Interpreter::inst() {
    // Meyers singleton: constructed on first use.
    static InterpreterImpl inst_;
    return inst_;
}

139
// Upload a host tensor into the channel and return an opaque handle to it.
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    // Scope a "Put" frame on the profiler call stack for this API call.
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    return reinterpret_cast<Handle>(put_impl(value, no_cache));
}

TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    if (value.empty()) {
        // Normalize empty tensors to a contiguous layout before recording.
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto info = alloc();
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {value.layout(), value.comp_node()});
    if (value.layout().total_nr_elems() <= size_threshold) {
        // Small tensors keep a host-side copy so later value/shape inference
        // can be answered without waiting for the worker.
        info->h_value = value;
        info->desc.value = value.proxy_to_default_cpu();
    }
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Put{info, value, no_cache},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Put{info, value, no_cache},
        });
    }
    if (m_async_level == 0) {
        // Fully synchronous mode: drain the worker and surface any pending
        // async device error immediately.
        sync_impl();
        info->desc.comp_node.sync();
        auto err = info->desc.comp_node.check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info;
}

180
// Wrap an existing device tensor (with optional host mirror) into a handle.
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return reinterpret_cast<Handle>(put_impl(data, hvalue));
}
M
Megvii Engine Team 已提交
185 186
TensorInfo* ChannelImpl::put_impl(
        const DeviceTensorND& data, const HostTensorND& hvalue) {
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    auto info = alloc();
    MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {data.layout(), data.comp_node()});
    if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
        // Keep a small host value for value-based inference.
        info->desc.value = hvalue.proxy_to_default_cpu();
    }
    // The device data already exists, so the tensor is Produced right away —
    // no worker round-trip is required.
    info->ptr = Tensor::make(data, hvalue);
    MGB_RECORD_EVENT(
            TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
            data.raw_ptr());
    info->status = TensorInfo::Produced;
    MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
    return info;
}

205
void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    // After close() every handle is already gone; deleting then is a no-op.
    if (!check_available()) {
        return;
    }
    del_impl(handle);
}

void ChannelImpl::del_impl(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    // Remove the handle from the valid set first, then let the worker free
    // the underlying TensorInfo asynchronously.
    m_valid_handle.erase(handle);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Del{info},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Del{info},
        });
    }
}

229
void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    // Drop requests are silently ignored unless the drop optimization is on.
    if (state.options.enable_drop) {
        mgb_assert(
                m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), Drop{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    Drop{info},
            });
        }
    }
}

251
void ChannelImpl::dispatch_default_cpu(
M
Megvii Engine Team 已提交
252
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
253 254
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
255
    auto& state = get_channel_state();
256 257

    auto name = op->trait()->make_name(*op);
258
    auto _ = StackManager::Guard(name, &state.stack_manager);
259

M
Megvii Engine Team 已提交
260 261
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
262
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
263

264 265
    SmallVector<DeviceTensorND> input_tensornds;
    CompNode output_cn;
266 267
    {
        MGB_LOCK_GUARD(m_mutex);
268
        for (auto&& info : input_infos) {
269
            auto input_cn = info->desc.comp_node;
270
            if (!output_cn.valid()) {
271 272 273 274 275 276
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
M
Megvii Engine Team 已提交
277 278
                input_tensornds.emplace_back(
                        info->ptr->get_value().proxy_to_default_cpu());
279
            } else {
280
                // We assign h_value before drop ptr
281 282
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
283 284 285 286 287 288 289 290 291
            }
        }
    }

    SmallVector<DeviceTensorND> output_tensornds;
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which need alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
M
Megvii Engine Team 已提交
292 293
        output_tensornds.emplace_back(
                HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
294 295
    }

296
    uint64_t op_id = Profiler::next_id();
297

298 299 300 301 302 303 304 305 306
    if (op->trait()->apply_on_device_tensornd) {
        OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
    } else {
        // proxy to apply_on_physical_tensor
        SmallVector<TensorPtr> input_tensors;
        for (auto&& input_tensornd : input_tensornds) {
            input_tensors.push_back(Tensor::make(
                    input_tensornd, HostTensorND::make_proxy(input_tensornd)));
        }
307 308
        auto output_tensors = OpDef::apply_on_physical_tensor(
                *op, input_tensors, output_descs, validated);
309 310 311 312
        for (size_t i = 0; i < output_tensors.size(); ++i) {
            output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
        }
    }
313 314 315

    SmallVector<TensorInfo*> output_infos;
    for (auto&& tensornd : output_tensornds) {
M
Megvii Engine Team 已提交
316 317
        HostTensorND host_tensornd =
                HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
318
        // use `put` for consistency
319
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
320
        mgb_assert(info->desc.layout.ndim != 0);
321
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
322
        outputs->push_back(reinterpret_cast<Handle>(info));
323
    }
M
Megvii Engine Team 已提交
324
    auto op_info_getter = [op] {
325 326
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
M
Megvii Engine Team 已提交
327
        for (auto&& [key, value] : props) {
328 329 330 331
            op_info[key] = value;
        }
        return op_info;
    };
M
Megvii Engine Team 已提交
332 333 334
    MGB_RECORD_EVENT(
            OpDispatchEvent, op_id, name, op_info_getter, tinfo_to_tid(input_infos),
            tinfo_to_tid(output_infos), state.stack_manager.dump());
335
}
336

337
void ChannelImpl::dispatch_kernel(
M
Megvii Engine Team 已提交
338
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
339 340
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
341
    auto& state = get_channel_state();
342 343
    auto& options = state.options;

344 345 346
    auto name = op->trait()->make_name(*op);
    auto _ = StackManager::Guard{name, &state.stack_manager};

M
Megvii Engine Team 已提交
347 348
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
349
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
350

351 352 353 354
    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());

    outputs->reserve(output_descs.size());
355 356
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
357
        auto info = alloc();
358
        init(info, std::move(desc));
359 360 361
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
M
Megvii Engine Team 已提交
362
                                    .proxy_to_comp_node(desc.comp_node);
363
        }
364
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
365
        outputs->push_back(reinterpret_cast<Handle>(info));
366
    }
367 368 369
    ApplyOp cmd{
            Profiler::next_id(), std::move(op), std::move(input_infos),
            std::move(output_infos), validated};
370
    if (Profiler::is_profiling()) {
371 372 373 374 375 376 377 378
        auto op_info_getter = [op = cmd.op] {
            std::unordered_map<std::string, std::string> op_info;
            auto props = OpDef::props(*op);
            for (auto&& [key, value] : props) {
                op_info[key] = value;
            }
            return op_info;
        };
379
        MGB_RECORD_EVENT(
380 381
                OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
                tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
382
        m_worker.add_task(
383
                {Profiler::next_id(), std::move(cmd),
384 385 386 387
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
388
                std::move(cmd),
389 390
        });
    }
391
    if (!validated && options.async_level == 1) {
392
        sync_impl();
393
    } else if (options.async_level == 0) {
394
        sync_impl();
395
        // check device error
396
        for (auto&& oup : *outputs) {
397 398
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
399 400
            auto err = info->ptr->comp_node().check_async_error();
            mgb_assert(!err, "%s", err->what());
401
        }
402
    }
403 404 405
}

// Public op-apply entry point. Contains a fast path: GetVarShape over a
// tensor whose shape is already statically known is answered inline, without
// a worker round-trip.
//
// Fix: the previous revision dereferenced inputs[0] before any check, which
// was undefined behavior for an empty input list; the fast path is now
// guarded by !inputs.empty().
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    if (!inputs.empty() && op->same_type<GetVarShape>()) {
        auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
        if (input->desc.layout.ndim) {
            size_t ndim = input->desc.layout.ndim;
            auto& gvs = op->cast_final_safe<GetVarShape>();
            // Only the full-shape query (no axis selected) is short-circuited.
            if (gvs.axis == MEGDNN_MAX_NDIM) {
                HostTensorND shape_tensor{
                        input->desc.comp_node, {ndim}, dtype::Int32()};
                DeviceTensorND shape_tensor_device =
                        shape_tensor.proxy_to_default_cpu();
                cg::copy_shape_to_tensor_value(
                        shape_tensor_device, input->desc.layout);
                return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
            }
        }
    }
    return apply_op_impl(std::move(op), inputs);
}

SmallVector<Handle> ChannelImpl::apply_op_impl(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(
                m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
                i);
    }
    // Snapshot input infos and descriptors under the info spinlock.
    SmallVector<TensorInfo*> input_infos;
    SmallVector<LogicalTensorDesc> input_descs;
    {
        MGB_LOCK_GUARD(m_info_spin);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(
                    !info->invalid,
                    "an input tensor is unusable due to previous error");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    // With host compute enabled, the op may choose CPU fast-path dispatch;
    // otherwise everything goes through the kernel path.
    DispatchMode dispatch_mode = state.options.enable_host_compute
                                       ? OpDef::decide_dispatch_mode(*op, input_descs)
                                       : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

462
// Blockingly fetch the host value of a tensor.
HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // do not rely on info->value_fetched here — it is unsafe
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

474
// Return the tensor shape, answering from the cached descriptor when the
// shape is already statically known and only blocking otherwise.
TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

489
// Return the dtype from the cached descriptor (never blocks on the worker).
DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

502
// Return the comp node from the cached descriptor (never blocks).
CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

515
// Blockingly fetch the device tensor behind a handle.
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

void ChannelImpl::sync() {
526
    MGB_LOCK_GUARD(m_spin);
527
    mgb_assert(check_available(), "Channel already closed");
528 529 530 531
    sync_impl();
}

void ChannelImpl::sync_impl() {
532 533 534 535 536 537
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
538
    MGB_LOCK_GUARD(m_spin);
539 540 541 542
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
M
Megvii Engine Team 已提交
543
    for (auto* handle : valid_handles) {
544
        del_impl(handle);
545 546 547
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
548
    sync_impl();
549
    m_closed = true;
550 551
}

552
size_t ChannelImpl::get_option(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return get_channel_state().options.get_option(name);
}

559
void ChannelImpl::set_option(std::string name, size_t value) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    // Apply immediately on the channel side, then forward to the worker so
    // its copy of the options stays in sync.
    state.options.set_option(name, value);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), SetOption{name, value},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                SetOption{name, value},
        });
    }
}

576 577 578 579 580 581
// Forget all DTR eviction candidates.
void ChannelImpl::clear_candidates() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_dtr.candidates.clear();
}

582
TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    // Take from the pool under its spinlock; the rest needs no lock.
    auto info = [this] {
        MGB_LOCK_GUARD(m_pool_spin);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        // Name the tensor after the current dispatch stack for profiler output.
        size_t tensor_id = state.stack_manager.current()->next_id("tensor");
        info->name =
                state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
    }
    return info;
}

597
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
    m_valid_handle.insert(reinterpret_cast<Handle>(info));
    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    // NOTE: `desc` is deliberately copied (not moved) — dispatch_kernel still
    // reads the caller's desc after invoking init().
    info->desc = desc;
}

M
Megvii Engine Team 已提交
604
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
    if (!ptr->producer) {
        // Without a producer the tensor cannot be regenerated, so dropping it
        // would lose data; warn only for user-initiated drops.
        if (user) {
            mgb_log_warn(
                    "the input that produced tensor %p has been deleted, this drop "
                    "operation will be ignored",
                    ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;  // already evicted one way or another
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

622
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

void ChannelImpl::recursive_free(TensorInfo* ptr) {
    MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
    // Free `ptr`, then cascade to any producer input whose refcount hits zero
    // and which is already marked deletable.
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
}

void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    // Unlink from the dependency graph before returning to the pool.
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    MGB_LOCK_GUARD(m_pool_spin);
    m_pool.free(ptr);
}

675
ChannelImpl::ChannelImpl() : m_worker(this) {}

ChannelImpl::~ChannelImpl() {
    // close() is idempotent, so explicit close followed by destruction is safe.
    close();
}
680

681
// Worker-side: publish a computed tensor into `dest` and wake any waiters.
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
            ptr->dev_tensor(false).raw_ptr());
    // update tensor desc for static infer
    if (dest->desc.layout.ndim) {
        mgb_assert(
                dest->desc.layout.eq_shape(ptr->layout()),
                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
                ptr->layout().to_string().c_str());
    }
    // in order to avoid performance impact,
    // memory forwarding is disabled when DTR is enabled
    if (state.options.enable_dtr_auto_drop) {
        ptr->to_contiguous_inplace();
    }
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    if (dest->pinned == 0 &&
        dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

713
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    // Drop the payload; the TensorInfo itself stays alive (it may be
    // regenerated later if it was DROP-evicted).
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
}

723
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        // Recompute a dropped tensor by replaying its producing op.
        auto&& path = dest->producer;
        m_apply_stack.push(
                {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
                 "dtr"});
        // If a flush is already in progress the pushed command will be picked
        // up by the ongoing loop; avoid re-entering it.
        if (!m_applying)
            flush_apply_stack();
    }
}

734
void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
735 736
    using namespace ranges;
    using namespace ranges::views;
737
    auto& state = get_worker_state();
M
Megvii Engine Team 已提交
738 739
    bool profiling_device =
            Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
740
    uint64_t apply_id = cmd.id;
741
    SmallVector<TensorPtr> inputs;
742
    inputs.reserve(cmd.inputs.size());
743 744 745
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
746
        // refcnt ++, owners: [i->ptr, tensor_inputs]
747
        // tensor_inputs.push_back(i->ptr);
748
        inputs.push_back(i->ptr);
749
    }
M
Megvii Engine Team 已提交
750 751
    if (state.options.enable_dtr_auto_drop &&
        state.options.dtr_eviction_threshold > 0) {
752 753
        auto_evict(0);
    }
M
Megvii Engine Team 已提交
754
    auto apply_on_physical_tensor =
755
            [&](auto&& self, const OpDef& def, SmallVector<TensorPtr>&& inputs,
756 757
                SmallVector<LogicalTensorDesc>& output_descs,
                const bool& validated) -> SmallVector<TensorPtr> {
758
        if (def.trait()->make_forward_graph) {
759 760 761 762 763 764 765 766 767 768
            auto apply_functor = [&](std::shared_ptr<OpDef> op,
                                     SmallVector<TensorPtr> inputs,
                                     size_t nr_outputs) -> SmallVector<TensorPtr> {
                auto opname = op->trait()->make_name(*op);
                imperative_log_profile_begin(opname.c_str());
                auto outputs = self(self, *op, std::move(inputs), output_descs, false);
                imperative_log_profile_end(opname.c_str());
                return outputs;
            };
            auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
769 770
            // apply recursivily
            SmallVector<LogicalTensorDesc> input_descs;
M
Megvii Engine Team 已提交
771
            for (auto&& input : inputs) {
772
                input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
773
            }
774
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
775 776
            auto outputs = forward_graph.apply<TensorPtr>(
                    inputs, apply_functor, const_functor);
777 778
            return outputs;
        }
779 780 781 782 783 784 785 786 787 788 789 790
        // Check Input Layout
        // Get the input layout constraints, and if the constraint is not satisfied
        // inplace update the layout and blob to make the tensor contiguous
        auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
        for (size_t idx = 0; idx < inputs.size(); ++idx) {
            auto&& layout_checker = constraints[idx];
            if (layout_checker) {
                inputs[idx]->to_contiguous_inplace(layout_checker);
            }
        }
        return OpDef::apply_on_physical_tensor(
                def, std::move(inputs), output_descs, validated);
791
    };
792
    MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
793 794 795 796
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
797 798 799
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
800
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
801 802 803
            }
        }
    }
M
Megvii Engine Team 已提交
804
    for (auto* input : cmd.inputs) {
805
        auto input_id = input->id;
806 807 808
        MGB_RECORD_EVENT(OpInputEvent, input_id);
        MGB_RECORD_EVENT(TensorUsageEvent, input_id);
        MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
809 810
    }
    // Before wait
M
Megvii Engine Team 已提交
811
    // TODO: split operator wait and execute so that OpWait could be corrected recorded.
812
    // Before execute
M
Megvii Engine Team 已提交
813
    for (auto&& [device, kernel_id] : kernels) {
814
        MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
M
Megvii Engine Team 已提交
815
        MGB_RECORD_EVENT_IF(
816 817
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
818 819
    }
    // Apply op
820
    SmallVector<LogicalTensorDesc> output_descs;
821 822 823 824 825 826 827
    bool validated = cmd.validated;
    if (!state.options.enable_dtr_auto_drop) {
        for (auto i : cmd.outputs) {
            output_descs.push_back(i->desc);
        }
    } else {
        validated = false;
828
    }
829
    // Here std::move is REQUIRED for removing duplicated references.
830
    auto outputs = apply_on_physical_tensor(
831
            apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs,
832
            validated);
833
    // After execute
M
Megvii Engine Team 已提交
834 835
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT_IF(
836 837
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
838
        MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
839 840
    }
    // End profiling operator
841 842
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
843
        auto output = cmd.outputs[i];
844
        if (mgb_unlikely(output == nullptr)) {
845 846
            MGB_RECORD_EVENT(OpOutputEvent, 0);
            MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
847
        } else if (mgb_unlikely(output->ptr != nullptr)) {
848 849
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
850
        } else {
851
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
852
            produce_tensor(output, outputs[i]);
853
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
854
            sample_on_device(output->desc.comp_node, false);
855 856 857 858 859 860 861 862
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
863
        for (auto i : outputs) {
864
            estimate_compute_time += i->blob()->size();
865 866 867 868 869 870 871
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
872
        m_dtr.unpin(cmd.inputs, state);
873
    }
874
    MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
875
    // End profiling operator
876
}
877

878 879
void ChannelImpl::flush_apply_stack() {
    m_applying = true;
880
    auto& state = get_worker_state();
881
    while (!m_apply_stack.empty()) {
M
Megvii Engine Team 已提交
882 883
        auto& [cmd, idx, recomp, reason] =
                m_apply_stack.top();  // cmd.inputs[0~idx-1] is in memory
884 885 886 887 888
        if (idx == 0) {
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
M
Megvii Engine Team 已提交
889 890
                MGB_RECORD_EVENT(
                        TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
891 892 893
            }
        }
        bool regen = false;
M
Megvii Engine Team 已提交
894
        for (size_t i = idx; i < cmd.inputs.size(); i++) {
895 896 897 898 899 900
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                idx = i + 1;
M
Megvii Engine Team 已提交
901
                regenerate(p);  // add ApplyOp to the stack
902 903 904 905
                regen = true;
                break;
            }
        }
M
Megvii Engine Team 已提交
906 907
        if (regen)
            continue;
908
        // the required input tensors are already in memory
M
Megvii Engine Team 已提交
909 910
        auto [cmd_backup, recomp_backup, reason_backup] =
                std::make_tuple(cmd, recomp, reason);
911
        m_apply_stack.pop();
912
        do_apply_op(cmd_backup, reason_backup);
913
        if (recomp_backup) {
M
Megvii Engine Team 已提交
914 915 916
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, recomp_backup->id,
                    TensorCommandKind::ReGen);
917 918
            for (auto o : cmd_backup.outputs) {
                if (o) {
919 920 921 922
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
923
    }
924
    m_applying = false;
925 926
}

927
bool ChannelImpl::auto_evict(size_t force_num) {
928
    auto& state = get_worker_state();
929
    if (!m_dtr.comp_node.valid()) {
930
        return false;
931 932
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
933
    size_t flag = false;
M
Megvii Engine Team 已提交
934 935 936
    while ((state.options.dtr_eviction_threshold > 0 &&
            current_memory > state.options.dtr_eviction_threshold) ||
           force_num > 0) {
937
        MGB_RECORD_EVENT(AutoEvictEvent);
938
        sample_on_device(m_dtr.comp_node, false);
939
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
940
        if (!best) {
941
            MGB_RECORD_EVENT(AutoEvictFinishEvent);
942 943 944 945
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
946
            if (force_num > 0) {
M
Megvii Engine Team 已提交
947
                force_num--;
948 949
            }
            flag = true;
950 951 952 953
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
954
        }
955
        sample_on_device(m_dtr.comp_node, false);
956
        MGB_RECORD_EVENT(AutoEvictFinishEvent);
957
    }
958
    return flag;
959 960
}

961 962
void ChannelImpl::detach_users(TensorInfo* dest) {
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
M
Megvii Engine Team 已提交
963
    for (auto* user : users) {
964 965
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
M
Megvii Engine Team 已提交
966 967 968 969 970
        for (auto* output : outputs) {
            // When a `ComputePath` is detach from it's input,
            // there is no need to reserve it,
            // so we detach all output of this path
            // to decrease it's `ref_cnt` to zero.
971 972 973 974 975
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
M
Megvii Engine Team 已提交
976 977
            for (auto* input : inputs) {
                input->ref_cnt--;
978
            }
979
        }
980
        // now user is dead
981
    }
982
    mgb_assert(dest->users.empty(), "ComputePath leaking");
983 984
}

985 986 987 988
// Returns true while the channel is open (i.e. close() has not been called).
bool ChannelImpl::check_available() {
    return !m_closed;
}

989 990 991 992 993
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
994
    MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
995
    bool require_host = prop == TensorProp::HostValue;
M
Megvii Engine Team 已提交
996
    auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
997 998
    bool wait_host = false;
    if (require_host && !host_available()) {
999 1000
        // avoid dead lock
        lock.unlock();
1001 1002 1003 1004 1005 1006 1007 1008 1009 1010
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), GetValue{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    GetValue{info},
            });
        }
1011
        lock.lock();
1012
        wait_host = true;
1013
    }
1014 1015
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
1016
        return require_host ? host_available() : static_cast<bool>(info->ptr);
1017
    });
1018
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
1019
    m_waitee = nullptr;
1020
    if (wait_host) {
1021 1022 1023
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
1024 1025 1026 1027 1028
    return info->ptr;
}

// Wake wait_tensor() if `info` is the tensor currently being waited on.
// Caller must hold m_mutex (hence "_unsafe").
void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
    }
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
M
Megvii Engine Team 已提交
1036
    for (auto* handle : m_valid_handle) {
1037 1038
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
1039
    }
1040
    return valid_tensors;
1041 1042
}

1043
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
1044
    bool in_worker = (get_worker_tid() == std::this_thread::get_id());
1045 1046 1047 1048 1049 1050
    auto reserve_size = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            bool evict_suc = auto_evict(1);
M
Megvii Engine Team 已提交
1051 1052
            if (!evict_suc)
                return false;
1053 1054 1055 1056
        }
        return true;
    };
    auto pre_level = set_log_level(LogLevel::NO_LOG);
1057 1058 1059
    if (in_worker) {
        reserve_size(x->size());
    }
1060
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
1061 1062
    MGB_CATCH(MemAllocError&, {
        bool suc = false;
1063 1064 1065 1066 1067 1068 1069 1070
        if (in_worker) {
            while (!suc) {
                if (!auto_evict(1)) {
                    break;
                }
                MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
                MGB_CATCH(MemAllocError&, { continue; });
                suc = true;
1071 1072 1073 1074
            }
        }
        if (!suc) {
            set_log_level(pre_level);
M
Megvii Engine Team 已提交
1075 1076 1077
            mgb_log_warn(
                    "reallocating all cuda memory to alleviate fragmentation, the "
                    "performance may be affected");
1078
            set_log_level(LogLevel::NO_LOG);
1079
            imperative_log_profile_begin("defrag");
1080
            BlobManager::inst()->defrag(x->comp_node());
1081
            imperative_log_profile_end("defrag");
1082
            BlobManager::inst()->alloc_direct(x, x->size());
1083 1084 1085 1086 1087
        }
    });
    set_log_level(pre_level);
}

1088
void ChannelImpl::process_one_task(Command& icmd) {
1089 1090
    using namespace ranges;
    using namespace ranges::views;
1091
    auto& state = get_worker_state();
1092
    auto& options = state.options;
M
Megvii Engine Team 已提交
1093
    // TODO: remove std::visit for support osx 10.12
1094
    auto cmd_visitor = [&](const auto& cmd) {
M
Megvii Engine Team 已提交
1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, Put>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
                                      : Tensor::make(cmd.value);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            produce_tensor(cmd.dest, std::move(value));
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, ApplyOp>) {
            for (auto& i : cmd.inputs) {
1112
                if (mgb_unlikely(i->invalid)) {
M
Megvii Engine Team 已提交
1113 1114 1115
                    MGB_LOCK_GUARD(m_mutex);
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
1116
                    }
M
Megvii Engine Team 已提交
1117 1118 1119
                    return;
                }
            }
1120 1121 1122 1123 1124 1125 1126 1127
            if (state.options.enable_dtr_auto_drop) {
                m_apply_stack.push({cmd, 0, nullptr, "cmd"});
                flush_apply_stack();
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
M
Megvii Engine Team 已提交
1128
                    output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
1129
                }
1130 1131
            } else {
                do_apply_op(cmd, "cmd");
M
Megvii Engine Team 已提交
1132 1133 1134 1135 1136 1137 1138
            }
            if (state.options.enable_drop && state.options.record_computing_path) {
                auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                    auto& input = std::get<0>(tuple2);
                    auto& output = std::get<1>(tuple2);
                    if (!input->ptr || !output->ptr) {
                        return false;
1139
                    }
M
Megvii Engine Team 已提交
1140 1141 1142 1143 1144 1145 1146
                    return input->ptr->blob()->storage() ==
                           output->ptr->blob()->storage();
                };
                // FIXME: do not use opname as identifier
                auto get_name = [](const OpDef& opdef) {
                    if (auto attr = opdef.try_cast_final<OprAttr>()) {
                        return attr->type.c_str();
1147
                    }
M
Megvii Engine Team 已提交
1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160
                    return opdef.dyn_typeinfo()->name;
                };

                auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
                    return info->desc.comp_node != comp_node;
                };

                bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                bool inplace =
                        any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                    TensorInfo::ComputePath::make(
1161
                            cmd.id, cmd.op, cmd.inputs, cmd.outputs);
M
Megvii Engine Team 已提交
1162 1163
                    size_t detach_cnt = 0;
                    if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
1164
                        cmd.outputs.size() == 6) {
M
Megvii Engine Team 已提交
1165 1166
                        cmd.outputs[0]->detach_producer();  // detach running_mean
                        cmd.outputs[1]->detach_producer();  // detach running_var
1167
                        for (auto input : cmd.inputs) {
M
Megvii Engine Team 已提交
1168
                            input->ref_cnt -= 2;
1169 1170
                        }
                    }
M
Megvii Engine Team 已提交
1171 1172 1173 1174 1175 1176 1177
                    for (auto output : cmd.outputs) {
                        if (output->producer &&
                            !output->size_exceeds_thd(
                                    state.options.dtr_evictee_minimum_size)) {
                            output->detach_producer();
                            detach_cnt++;
                        }
1178
                    }
M
Megvii Engine Team 已提交
1179 1180
                    for (auto input : cmd.inputs) {
                        input->ref_cnt -= detach_cnt;
1181
                    }
1182
                }
1183
            }
M
Megvii Engine Team 已提交
1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199
        } else if constexpr (std::is_same_v<T, Del>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
            CompNode device = cmd.dest->desc.comp_node;
            uint64_t tensor_id = cmd.dest->id;
            free(cmd.dest);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
            sample_on_device(device, false);
        } else if constexpr (std::is_same_v<T, GetValue>) {
            if (cmd.dest->invalid)
                return;
            imperative_log_profile_begin("GetValue");
            if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                regenerate(cmd.dest);
            }
            cmd.dest->ptr->fetch_value();
1200
            MGB_LOCK_GUARD(m_mutex);
M
Megvii Engine Team 已提交
1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
            do_drop(cmd.dest, true);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
        } else if constexpr (std::is_same_v<T, SetOption>) {
            options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            MGB_RECORD_EVENT(StartProfileEvent);
            CompNode::sync_all();
            for (auto* info : cmd.capture_tensors) {
                MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                if (info->status == TensorInfo::Produced) {
1218
                    // TODO: handle drop
M
Megvii Engine Team 已提交
1219 1220 1221
                    MGB_RECORD_EVENT(
                            TensorProduceEvent, info->id, info->desc.layout,
                            info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
1222 1223
                }
            }
M
Megvii Engine Team 已提交
1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238
            CompNode::foreach ([&](CompNode device) {
                sample_on_device(device, true);
                MGB_RECORD_EVENT_IF(
                        (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                        Timer::record_device(device));
            });
            MGB_RECORD_EVENT(StartProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
            MGB_RECORD_EVENT(StopProfileEvent);
            for (auto* info : cmd.escape_tensors) {
                bool has_value = info->status == TensorInfo::Produced;
                if (has_value) {
                    MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
                }
                MGB_RECORD_EVENT(TensorEraseEvent, info->id);
1239
            }
M
Megvii Engine Team 已提交
1240 1241 1242 1243 1244 1245 1246 1247 1248
            CompNode::foreach (
                    [&](CompNode device) { sample_on_device(device, true); });
            MGB_RECORD_EVENT(StopProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, PushScope>) {
            MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
            MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
        } else {
            static_assert(!std::is_same_v<T, T>);
1249
        }
M
Megvii Engine Team 已提交
1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276
    };
    std::visit(
            [&](const auto& cmd) {
                using T = std::decay_t<decltype(cmd)>;
                if (!options.catch_worker_execption) {
                    cmd_visitor(cmd);
                    return;
                }
                try {
                    cmd_visitor(cmd);
                } catch (...) {
                    MGB_LOCK_GUARD(m_mutex);
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        for (auto oup : cmd.outputs) {
                            oup->invalid = true;
                        }
                    } else if constexpr (std::is_same_v<T, Put>) {
                        cmd.dest->invalid = true;
                    }
                    m_worker_exc = std::current_exception();
                    MGB_RECORD_EVENT(WorkerExceptionEvent);
                    if (m_waitee) {
                        notify_tensor_unsafe(m_waitee);
                    }
                }
            },
            icmd.data);
1277 1278 1279 1280
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
1281 1282
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
1283 1284
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
1285 1286 1287 1288 1289
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
1290 1291
    }
}
1292

1293
void ChannelImpl::start_profile() {
1294
    MGB_LOCK_GUARD(m_spin);
1295
    mgb_assert(check_available(), "Channel already closed");
1296 1297
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
1298 1299 1300 1301 1302 1303 1304 1305 1306 1307
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StartProfile{std::move(capture_tensors)},
            });
        }
1308
    }
1309 1310
}

1311
void ChannelImpl::stop_profile() {
1312
    MGB_LOCK_GUARD(m_spin);
1313
    mgb_assert(check_available(), "Channel already closed");
1314 1315
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
1316 1317 1318 1319 1320 1321 1322 1323 1324 1325
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StopProfile{std::move(escape_tensors)},
            });
        }
1326
    }
1327 1328 1329
}

void ChannelImpl::push_scope(std::string name) {
1330
    MGB_LOCK_GUARD(m_spin);
1331
    mgb_assert(check_available(), "Channel already closed");
1332
    auto& state = get_channel_state();
1333
    state.stack_manager.enter(name);
1334
    MGB_RECORD_EVENT(ScopeEvent, name);
1335 1336 1337 1338 1339 1340 1341 1342 1343 1344
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PushScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PushScope{name},
        });
    }
1345 1346 1347
}

void ChannelImpl::pop_scope(std::string name) {
1348
    MGB_LOCK_GUARD(m_spin);
1349
    mgb_assert(check_available(), "Channel already closed");
1350
    auto& state = get_channel_state();
1351
    state.stack_manager.exit(name);
1352
    MGB_RECORD_EVENT(ScopeFinishEvent, name);
1353 1354 1355 1356 1357 1358 1359 1360 1361 1362
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PopScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PopScope{name},
        });
    }
1363 1364
}

1365
void ChannelImpl::assert_in_channel() {
M
Megvii Engine Team 已提交
1366 1367 1368
    mgb_assert(
            get_worker_tid() != std::this_thread::get_id(),
            "this method cannot be called in worker thread");
1369 1370 1371
}

void ChannelImpl::assert_in_worker() {
M
Megvii Engine Team 已提交
1372 1373 1374
    mgb_assert(
            get_worker_tid() == std::this_thread::get_id(),
            "this method can only be called in worker thread");
1375 1376
}

1377
void ChannelImpl::sample_on_device(CompNode device, bool force) {
1378 1379 1380
    if (!Profiler::is_profiling()) {
        return;
    }
1381 1382
    if (!force) {
        thread_local int last_sample_id = 0;
1383
        int sample_rate = Profiler::get_option("sample_rate", 0);
1384 1385 1386 1387
        if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
            return;
        }
    }
1388
    MGB_RECORD_EVENT(SampleDeviceEvent, device);
1389
    auto [total, free] = device.get_mem_status_bytes();
1390
    MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
1391 1392
}

1393 1394 1395
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
1396
        erase_candidate(i);
1397 1398 1399
    }
}

1400 1401
void ChannelImpl::DynamicSublinear::unpin(
        const SmallVector<TensorInfo*>& vec, WorkerState& state) {
1402 1403
    for (auto i : vec) {
        i->unpin();
1404 1405 1406 1407 1408
        if (i->pinned == 0 &&
            i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
            i->cand_index == UINT_MAX) {
            insert_candidate(i);
        }
1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454
    }
}

// After `ptr` has been recomputed, remove its compute time from its DSU
// component's aggregate cost and split it back out as a singleton set.
void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    // the component no longer pays for recomputing this tensor
    dsu_fa->t -= ptr->compute_time;
    // detach from the component: become a root carrying only its own cost
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

// After `ptr` has been evicted (dropped), merge its DSU node with every
// already-dropped neighbor (producer inputs and outputs) so that the
// aggregate recomputation cost of the connected evicted region is tracked
// in a single set.
void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        // outputs may be null (already deleted)
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    // Estimate the recomputation cost incurred if `ptr` is evicted: for every
    // dropped neighbour (producer input or sibling output), add the cost of
    // that neighbour's whole DSU set, floored at the neighbour's own compute
    // time.
    double total = 0;
    auto set_cost = [this](TensorInfo* nbr) {
        double t = find_father(nbr->dsu_ptr)->t;
        return t < nbr->compute_time ? nbr->compute_time : t;
    };
    for (auto* input : ptr->producer->inputs) {
        if (input->evict_type == EvictType::DROP) {
            total += set_cost(input);
        }
    }
    for (auto* output : ptr->producer->outputs) {
        if (output && output->evict_type == EvictType::DROP) {
            total += set_cost(output);
        }
    }
    return total;
}

TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
        bool enable_dtr_sqrt_sampling = false) {
    // Pick the eviction victim: scan the candidate pool (or, with sqrt
    // sampling, a random ~sqrt(N)-sized subset of it) and return the tensor
    // with the lowest eval_func score. Returns nullptr when no candidate
    // exists or none is currently evictable.
    if (candidates.empty())
        return nullptr;

    double min_msps = -1;
    TensorInfo* best = nullptr;
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
        // Largest sz such that sz * sz <= candidates.size().
        while (sz * sz <= candidates.size())
            sz++;
        sz--;
    } else {
        sz = candidates.size();
    }

    size_t ti = rand() % sz;
    for (size_t vi = 0; vi < sz; vi++) {
        if (!enable_dtr_sqrt_sampling) {
            // Full scan: visit every candidate in order.
            ti = vi;
        }
        auto i = candidates[ti];
        // Only score tensors that are recomputable (have a producer), are
        // still resident, and have no eviction already scheduled.
        if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
            double neighbor_cost = estimate_neighbor_cost(i);
            size_t begin_ptr =
                    reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(
                    begin_ptr, begin_ptr + i->ptr->blob()->size());
            double free_mem = side_info.first + side_info.second;
            double msps = i->eval_func(
                    neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
        if (enable_dtr_sqrt_sampling) {
            // Advance by a random stride for the next sample.
            ti += rand() % sz;
            // BUGFIX: was `ti > candidates.size()`, which let ti ==
            // candidates.size() survive into the next iteration and index
            // `candidates` one past the end.
            if (ti >= candidates.size())
                break;
        }
    }
    return best;
}

void ChannelImpl::DynamicSublinear::merge(
        std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
    // Union of two DSU sets: y's root absorbs x's accumulated compute time
    // and becomes the parent of x's root. No-op when both nodes already
    // belong to the same set.
    auto&& root_x = find_father(x);
    auto&& root_y = find_father(y);
    if (root_x.get() == root_y.get()) {
        return;
    }
    root_y->t += root_x->t;
    root_x->parent = root_y;
}

std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
        std::shared_ptr<DsuNode>& x) {
    // Find with path compression: every node visited on the way up is
    // re-parented directly onto the set root.
    if (!x->is_root()) {
        x->parent = find_father(x->parent);
        return x->parent;
    }
    return x;
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    // tensor to be inserted must be brand new
    mgb_assert(
            ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
            ptr->cand_index);
    // Append and record the tensor's slot so erase_candidate can swap-and-pop.
    candidates.push_back(ptr);
    ptr->cand_index = candidates.size() - 1;
    // Lazily latch the device the DTR pool lives on from the first insert.
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    // When DTR is switched off the whole candidate list is cleared at once;
    // only the tensor's stale index needs resetting here.
    if (candidates.empty()) {
        ptr->cand_index = UINT_MAX;
        return;
    }
    // Tensor was already removed earlier: nothing to do.
    if (ptr->cand_index == UINT_MAX) {
        return;
    }
    // O(1) removal: swap with the last element, fix the moved element's
    // back-pointer, then pop.
    size_t idx = ptr->cand_index;
    std::swap(candidates[idx], candidates.back());
    candidates[idx]->cand_index = idx;
    candidates.pop_back();
    ptr->cand_index = UINT_MAX;
}

// Stamp the tensor with the current estimated (logical) timestamp, recording
// when it was last touched. NOTE(review): presumably consumed by the DTR
// eviction heuristic (eval_func) as a recency signal — eval_func body not
// visible here to confirm.
void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}