interpreter_impl.cpp 54.2 KB
Newer Older
M
Megvii Engine Team 已提交
1
/**
2
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
M
Megvii Engine Team 已提交
3 4
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
M
Megvii Engine Team 已提交
6 7 8 9 10 11
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12
#include "./interpreter_impl.h"
13

14 15
#include "range/v3/all.hpp"

16
#include "megbrain/common.h"
17 18
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
19 20
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
21
#include "megbrain/imperative/ops/utility.h"
22
#include "megbrain/imperative/utils/stats.h"
23 24
#include "megbrain/imperative/utils/to_string.h"

25
#include "../blob_manager_impl.h"
26 27 28
#include "../event_pool.h"
#include "../op_trait.h"

29 30 31 32 33
using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

34
namespace {
M
Megvii Engine Team 已提交
35 36 37 38 39 40 41 42
auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
    SmallVector<uint64_t> tid;
    for (auto* ptinfo : tinfo) {
        tid.push_back(ptinfo->id);
    }
    return tid;
};
}  // namespace
43

44
namespace mgb {
M
Megvii Engine Team 已提交
45
using namespace profiler;
46 47
}

48 49 50 51 52
#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif
53 54 55 56 57 58 59

namespace mgb {

/**
 * Exported helpers for marking custom profiler ranges from external code.
 *
 * USAGE
 *
 *   header:
 *     namespace mgb { void imperative_log_profile(const char* message); }
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    // Opens a custom range on the profiler timeline.
    MGB_RECORD_EVENT(CustomEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    // Closes the range previously opened with the same message.
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
}

SYMBOL_EXPORT
void imperative_log_profile(const char* message) {
    // An instantaneous marker: a begin immediately followed by an end.
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message, const char* device) {
    // Device-scoped variant: additionally records a device event so the
    // range can be correlated with the device timeline.
    auto cn = CompNode::load(device);
    MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, cn);
    MGB_RECORD_EVENT(RecordDeviceEvent, EventPool::with_timer().alloc_shared(cn));
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message, const char* device) {
    auto cn = CompNode::load(device);
    // The device event is recorded before the finish marker, mirroring the
    // order used by the begin counterpart.
    MGB_RECORD_EVENT(RecordDeviceEvent, EventPool::with_timer().alloc_shared(cn));
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, cn);
}

}  // namespace mgb
99

100 101 102 103 104 105 106 107 108 109 110 111 112 113
// Thread id of the async worker thread, recorded when the worker starts.
// NOTE(review): read without the assert_in_* guards used by the state
// getters below — presumably safe because the tid is written once at worker
// startup; confirm.
std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

// Accessor for channel-thread state; asserts the caller is on the channel
// (user) thread before handing out a mutable reference.
ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

// Accessor for worker-thread state; asserts the caller is on the worker
// thread before handing out a mutable reference.
ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

114 115 116
void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
117
    auto custom_allocator = [&](CompNode device, size_t size) {
118 119 120
        auto blob = Blob::make(device, size);
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
121 122 123
    };
    OpDef::set_allocator(custom_allocator);
    BlobManager::inst()->set_allocator(custom_allocator);
124 125
}

126
// Do not use m_xxx_state directly
127 128 129
#define m_channel_state
#define m_worker_state

130 131 132 133 134 135 136 137 138
// Factory: each channel owns its own worker thread and tensor bookkeeping.
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

// Process-wide interpreter singleton (thread-safe magic-static init).
Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

139
// Public entry: upload a host tensor and return a handle for it.  Guards
// the channel, tags the call in the stack manager, then defers to put_impl.
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    auto guard = StackManager::Guard{"Put", &state.stack_manager};
    return reinterpret_cast<Handle>(put_impl(value, no_cache));
}

TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    // Give empty tensors a canonical contiguous layout so downstream code
    // can rely on initialized strides.
    if (value.empty()) {
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto* info = alloc();
    init(info, {value.layout(), value.comp_node()});
    // Cache small tensors on the host so value-dependent shape inference can
    // run without a device round-trip.
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    if (value.layout().total_nr_elems() <= size_threshold) {
        info->h_value = value;
        info->desc.value = value.proxy_to_default_cpu();
    }
    // Hand the actual upload to the worker; attach a stack dump only when
    // profiling to keep the fast path cheap.
    auto task_id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        m_worker.add_task({task_id, Put{info, value, no_cache},
                           get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({task_id, Put{info, value, no_cache}});
    }
    if (m_async_level == 0) {
        // Fully synchronous mode: wait for the worker and surface any async
        // device error immediately.
        sync_impl();
        info->desc.comp_node.sync();
        auto err = info->desc.comp_node.check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info;
}

180
// Public entry: wrap an existing device tensor (no copy) plus an optional
// host-side mirror of its value.
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto* info = put_impl(data, hvalue);
    return reinterpret_cast<Handle>(info);
}
M
Megvii Engine Team 已提交
185 186
TensorInfo* ChannelImpl::put_impl(
        const DeviceTensorND& data, const HostTensorND& hvalue) {
    auto& state = get_channel_state();
    auto guard = StackManager::Guard{"Put", &state.stack_manager};
    auto* info = alloc();
    MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
    init(info, {data.layout(), data.comp_node()});
    // Keep a host-side copy of the value for small tensors, enabling
    // value-dependent shape inference without a device readback.
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
        info->desc.value = hvalue.proxy_to_default_cpu();
    }
    // No worker round-trip: the data already lives on the device, so the
    // tensor is produced immediately.
    info->ptr = Tensor::make(data, hvalue);
    MGB_RECORD_EVENT(
            TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
            data.raw_ptr());
    info->status = TensorInfo::Produced;
    MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
    return info;
}

205
// Release a user handle.  Safe to call after close(): the channel then has
// no live handles, so this becomes a silent no-op instead of an error.
void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    del_impl(handle);
}

// Invalidate the handle on the channel thread; the actual tensor release
// happens asynchronously on the worker via a Del command.
void ChannelImpl::del_impl(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    auto task_id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        m_worker.add_task({task_id, Del{info},
                           get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({task_id, Del{info}});
    }
}

229
// Request eviction of a tensor's device storage (DTR).  Only meaningful
// when enable_drop is on; otherwise the request is silently ignored.
void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(
                m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        auto task_id = Profiler::next_id();
        if (Profiler::is_profiling()) {
            m_worker.add_task({task_id, Drop{info},
                               get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({task_id, Drop{info}});
        }
    }
}

251
// Execute the op eagerly on host (default CPU) instead of queueing it to
// the worker.  Only reached when OpDef::decide_dispatch_mode chose
// DEFAULT_CPU, i.e. every input value is available host-side.
void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();

    auto name = op->trait()->make_name(*op);
    auto _ = StackManager::Guard(name, &state.stack_manager);

    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_RECORD_EVENT(ShapeInferEvent, validated);

    // Gather host-visible views of all inputs; all inputs must agree on one
    // comp node, which also becomes the output comp node.
    SmallVector<DeviceTensorND> input_tensornds;
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(
                        info->ptr->get_value().proxy_to_default_cpu());
            } else {
                // We assign h_value before drop ptr
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    SmallVector<DeviceTensorND> output_tensornds;
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which need alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(
                HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    uint64_t op_id = Profiler::next_id();

    if (op->trait()->apply_on_device_tensornd) {
        OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
    } else {
        // proxy to apply_on_physical_tensor
        SmallVector<TensorPtr> input_tensors;
        for (auto&& input_tensornd : input_tensornds) {
            input_tensors.push_back(Tensor::make(
                    input_tensornd, HostTensorND::make_proxy(input_tensornd)));
        }
        auto output_tensors = OpDef::apply_on_physical_tensor(
                *op, input_tensors, output_descs, validated);
        for (size_t i = 0; i < output_tensors.size(); ++i) {
            output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
        }
    }

    // Wrap each result as a fresh channel tensor via put_impl so that event
    // recording and handle bookkeeping stay uniform with the normal path.
    SmallVector<TensorInfo*> output_infos;
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd =
                HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(reinterpret_cast<Handle>(info));
    }
    // Lazily materialize op properties for the profiler record.
    auto op_info_getter = [op] {
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value] : props) {
            op_info[key] = value;
        }
        return op_info;
    };
    MGB_RECORD_EVENT(
            OpDispatchEvent, op_id, name, op_info_getter, tinfo_to_tid(input_infos),
            tinfo_to_tid(output_infos), state.stack_manager.dump());
}
336

337
void ChannelImpl::dispatch_kernel(
M
Megvii Engine Team 已提交
338
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
339 340
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
341
    auto& state = get_channel_state();
342 343
    auto& options = state.options;

344 345 346
    auto name = op->trait()->make_name(*op);
    auto _ = StackManager::Guard{name, &state.stack_manager};

M
Megvii Engine Team 已提交
347 348
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
349
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
350

351 352 353 354
    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());

    outputs->reserve(output_descs.size());
355 356
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
357
        auto info = alloc();
358
        init(info, desc);
359 360 361
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
M
Megvii Engine Team 已提交
362
                                    .proxy_to_comp_node(desc.comp_node);
363
        }
364
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
365
        outputs->push_back(reinterpret_cast<Handle>(info));
366
    }
367 368 369
    ApplyOp cmd{Profiler::next_id(),     std::move(op),
                std::move(input_infos),  std::move(output_infos),
                std::move(output_descs), validated};
370
    if (Profiler::is_profiling()) {
371 372 373 374 375 376 377 378
        auto op_info_getter = [op = cmd.op] {
            std::unordered_map<std::string, std::string> op_info;
            auto props = OpDef::props(*op);
            for (auto&& [key, value] : props) {
                op_info[key] = value;
            }
            return op_info;
        };
379
        MGB_RECORD_EVENT(
380 381
                OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
                tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
382
        m_worker.add_task(
383
                {Profiler::next_id(), std::move(cmd),
384 385 386 387
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
388
                std::move(cmd),
389 390
        });
    }
391
    if (!validated && options.async_level == 1) {
392
        sync_impl();
393
    } else if (options.async_level == 0) {
394
        sync_impl();
395
        // check device error
396
        for (auto&& oup : *outputs) {
397 398
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
399 400
            auto err = info->ptr->comp_node().check_async_error();
            mgb_assert(!err, "%s", err->what());
401
        }
402
    }
403 404 405
}

// Apply an op to the given input handles and return handles to the results.
// Contains a fast path for GetVarShape when the input shape is statically
// known: the answer is produced directly from the cached desc without going
// through the worker.
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    // Guard on non-empty inputs first: the original unconditionally read
    // inputs[0], which is undefined behavior for ops applied to no inputs.
    if (!inputs.empty() && op->same_type<GetVarShape>()) {
        auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
        if (input->desc.layout.ndim) {
            size_t ndim = input->desc.layout.ndim;
            auto& gvs = op->cast_final_safe<GetVarShape>();
            if (gvs.axis == MEGDNN_MAX_NDIM) {
                HostTensorND shape_tensor{
                        input->desc.comp_node, {ndim}, dtype::Int32()};
                DeviceTensorND shape_tensor_device = shape_tensor.proxy_to_default_cpu();
                cg::copy_shape_to_tensor_value(shape_tensor_device, input->desc.layout);
                return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
            }
        }
    }
    return apply_op_impl(std::move(op), inputs);
}

// Common op-application path: validates handles, snapshots input descs and
// routes to either the host (DEFAULT_CPU) or worker (KERNEL) dispatcher.
SmallVector<Handle> ChannelImpl::apply_op_impl(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    auto& state = get_channel_state();
    // Validate every handle up front so errors point at the caller.
    for (auto handle : inputs) {
        mgb_assert(
                m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
    }
    SmallVector<TensorInfo*> input_infos;
    SmallVector<LogicalTensorDesc> input_descs;
    {
        // Snapshot descs under the info lock; the worker may update them
        // concurrently.
        MGB_LOCK_GUARD(m_info_spin);
        for (auto handle : inputs) {
            auto* info = reinterpret_cast<TensorInfo*>(handle);
            mgb_assert(
                    !info->invalid,
                    "an input tensor is unusable due to previous error");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    // Host compute is only considered when the option is enabled; the op
    // itself decides whether it is eligible.
    DispatchMode dispatch_mode = state.options.enable_host_compute
                                       ? OpDef::decide_dispatch_mode(*op, input_descs)
                                       : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

462
// Blocking read of a tensor's host value.
HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    // donnot use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

474
// Shape query: answered from the cached desc when inference already knows
// it, otherwise blocks until the worker produces the tensor.
TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        // Fast path: shape known statically, no need to block.
        return info->desc.layout;
    }
    TensorShape shape = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(shape.ndim != 0);
    return shape;
}

489
// dtype query; non-blocking since the dtype is part of the cached desc.
DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto dtype = info->desc.layout.dtype;
    mgb_assert(dtype.valid());
    return dtype;
}

502
// comp node query; non-blocking since it is part of the cached desc.
CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto device = info->desc.comp_node;
    mgb_assert(device.valid());
    return device;
}

515
// Blocks until the device tensor exists; no host readback is performed.
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

// Block until every queued command has been executed by the worker.
// Re-raises any exception the worker recorded while processing.
void ChannelImpl::sync() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    sync_impl();
}

// Internal sync without the spin-lock/availability check; also used by
// close() and the synchronous (async_level == 0) code paths.
void ChannelImpl::sync_impl() {
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
538
    MGB_LOCK_GUARD(m_spin);
539 540 541 542
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
M
Megvii Engine Team 已提交
543
    for (auto* handle : valid_handles) {
544
        del_impl(handle);
545 546 547
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
548
    sync_impl();
549
    m_closed = true;
550 551
}

552
// Read a named option from the channel-side option table.
size_t ChannelImpl::get_option(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    return state.options.get_option(name);
}

559
void ChannelImpl::set_option(std::string name, size_t value) {
560
    MGB_LOCK_GUARD(m_spin);
561
    mgb_assert(check_available(), "Channel already closed");
562 563
    auto& state = get_channel_state();
    state.options.set_option(name, value);
564 565 566 567 568 569 570 571 572 573
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), SetOption{name, value},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                SetOption{name, value},
        });
    }
574 575
}

576 577 578 579 580 581
// Discard all DTR eviction candidates tracked by this channel.
void ChannelImpl::clear_candidates() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_dtr.candidates.clear();
}

582
// Allocate a TensorInfo from the pooled storage and assign it a profiler id
// (plus a human-readable name when profiling is active).
TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    TensorInfo* info = [this] {
        // Placement-new out of the pool; the pool has its own spin lock.
        MGB_LOCK_GUARD(m_pool_spin);
        auto* storage = m_pool.alloc_raw();
        return new (storage) TensorInfo();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        size_t tensor_id = state.stack_manager.current()->next_id("tensor");
        info->name =
                state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
    }
    return info;
}

599
// Register a freshly allocated TensorInfo: publish its handle, record the
// declare event and install its logical desc.
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
    m_valid_handle.insert(reinterpret_cast<Handle>(info));
    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    // `desc` is taken by value and unused afterwards: move instead of copy
    // (the desc may carry a host value tensor).
    info->desc = std::move(desc);
}

M
Megvii Engine Team 已提交
606
// Evict a tensor's device storage while keeping enough metadata (producer
// op and inputs) to regenerate it on demand.  `user` marks a user-initiated
// drop, which warrants a warning when the drop cannot be honoured.
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
    if (!ptr->producer) {
        // No producer recorded: the tensor cannot be recomputed, so dropping
        // it would lose the value.  Ignore the request.
        if (user) {
            mgb_log_warn(
                    "the input that produced tensor %p has been deleted, this drop "
                    "operation will be ignored",
                    ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        // Already evicted; nothing to do.
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

624
// Worker-side free.  Without DTR this releases the tensor immediately; with
// DTR the tensor is preferentially evicted (or recursively freed once no
// consumer references remain).
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (!state.options.enable_dtr_auto_drop) {
        real_free(ptr);
        return;
    }
    // Evicting a tensor, rather than freeing it, can avoid pinning
    // potentially exploding amounts of memory and allow us to save
    // more memory.
    ptr->allow_delete = true;
    if (ptr->ref_cnt) {
        do_drop(ptr);
    } else {
        recursive_free(ptr);
    }
}

// Free a tensor and cascade to producer inputs whose reference count drops
// to zero (and that were already marked deletable).
void ChannelImpl::recursive_free(TensorInfo* ptr) {
    MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
    SmallVector<TensorInfo*> dead_inputs;
    if (ptr->producer) {
        for (auto* inp : ptr->producer->inputs) {
            if (inp && --inp->ref_cnt == 0) {
                dead_inputs.push_back(inp);
            }
        }
    }
    real_free(ptr);
    for (auto* inp : dead_inputs) {
        if (inp->allow_delete) {
            recursive_free(inp);
        }
    }
    MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
}

// Fully destroy a TensorInfo: detach it from the dependency graph, emit the
// bookkeeping events and return its storage to the pool.
void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    if (ptr->ptr != nullptr) {
        // A value was still held: record its release first.
        MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    MGB_LOCK_GUARD(m_pool_spin);
    m_pool.free(ptr);
}

677
// The worker queue keeps a back-pointer to this channel (used e.g. to set
// up the worker thread's allocator and publish its tid).
ChannelImpl::ChannelImpl() : m_worker(this) {}
678

679 680 681
// close() flushes pending work and releases all live tensors; it is
// idempotent, so an explicit earlier close is fine.
ChannelImpl::~ChannelImpl() {
    close();
}
682

683
// Worker-side: install a freshly computed tensor value into `dest` and wake
// any channel-thread waiters (notify_tensor_unsafe).
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
            ptr->dev_tensor(false).raw_ptr());
    // update tensor desc for static infer
    if (dest->desc.layout.ndim) {
        // A previously inferred shape must match what was actually produced.
        mgb_assert(
                dest->desc.layout.eq_shape(ptr->layout()),
                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
                ptr->layout().to_string().c_str());
    }
    // in order to avoid performance impact,
    // memory forwarding is disabled when DTR is enabled
    if (state.options.enable_dtr_auto_drop) {
        ptr->to_contiguous_inplace();
    }
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    // Read the blob size before `ptr` is moved away below.
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    // Sufficiently large, unpinned tensors become DTR eviction candidates.
    if (dest->pinned == 0 &&
        dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

715
// Worker-side: free the device storage of `dest` while keeping its
// TensorInfo alive (a dropped tensor can later be regenerated).
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
}

725
// Recompute a dropped tensor by replaying its recorded producer op.  The
// replay command is pushed onto m_apply_stack; when an apply is already in
// flight (m_applying), the outer flush_apply_stack() picks the entry up.
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        auto&& path = dest->producer;
        m_apply_stack.push(
                {ApplyOp{path->id, path->op, path->inputs, path->outputs,
                         path->outputs_descs},
                 0, dest, "dtr"});
        if (!m_applying)
            flush_apply_stack();
    }
}

737
// Execute a single ApplyOp on the worker thread: gather input TensorPtrs,
// optionally trigger DTR auto-eviction, run the kernel (recursively for ops
// that expand into a forward graph), record profiling events around each
// step, publish the outputs via produce_tensor, and update DTR bookkeeping
// (estimated compute time, unpinning of inputs).  `reason` is only used for
// profiling annotation ("cmd" for normal execution, "dtr" for recompute).
void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    bool profiling_device =
            Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
    uint64_t apply_id = cmd.id;
    SmallVector<TensorPtr> inputs;
    inputs.reserve(cmd.inputs.size());
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
        // refcnt ++, owners: [i->ptr, tensor_inputs]
        // tensor_inputs.push_back(i->ptr);
        inputs.push_back(i->ptr);
    }
    // Evict down to the configured threshold before allocating outputs.
    if (state.options.enable_dtr_auto_drop &&
        state.options.dtr_eviction_threshold > 0) {
        auto_evict(0);
    }
    // Recursive executor (passed itself as `self` so the lambda can recurse):
    // ops with make_forward_graph are expanded and each sub-op applied in
    // turn; leaf ops go through the layout-constraint check and then
    // OpDef::apply_on_physical_tensor.
    auto apply_on_physical_tensor =
            [&](auto&& self, const OpDef& def, SmallVector<TensorPtr>&& inputs,
                SmallVector<LogicalTensorDesc>& output_descs,
                const bool& validated) -> SmallVector<TensorPtr> {
        if (def.trait()->make_forward_graph) {
            // Sub-ops are applied with validated=false since their output
            // descs were not statically checked.
            auto apply_functor = [&](std::shared_ptr<OpDef> op,
                                     SmallVector<TensorPtr> inputs,
                                     size_t nr_outputs) -> SmallVector<TensorPtr> {
                auto opname = op->trait()->make_name(*op);
                imperative_log_profile_begin(opname.c_str());
                auto outputs = self(self, *op, std::move(inputs), output_descs, false);
                imperative_log_profile_end(opname.c_str());
                return outputs;
            };
            auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
            // apply recursively
            SmallVector<LogicalTensorDesc> input_descs;
            for (auto&& input : inputs) {
                input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
            }
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
            auto outputs = forward_graph.apply<TensorPtr>(
                    inputs, apply_functor, const_functor);
            return outputs;
        }
        // Check Input Layout
        // Get the input layout constraints, and if the constraint is not satisfied
        // inplace update the layout and blob to make the tensor contiguous
        auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
        for (size_t idx = 0; idx < inputs.size(); ++idx) {
            auto&& layout_checker = constraints[idx];
            if (layout_checker) {
                inputs[idx]->to_contiguous_inplace(layout_checker);
            }
        }
        return OpDef::apply_on_physical_tensor(
                def, std::move(inputs), output_descs, validated);
    };
    MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
    // One (device, kernel_id) pair per distinct comp_node touched by this op,
    // used to bracket the kernel launch with device-timing events.
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
            }
        }
    }
    for (auto* input : cmd.inputs) {
        auto input_id = input->id;
        MGB_RECORD_EVENT(OpInputEvent, input_id);
        MGB_RECORD_EVENT(TensorUsageEvent, input_id);
        MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
    }
    // Before wait
    // TODO: split operator wait and execute so that OpWait could be correctly
    // recorded.
    // Before execute
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
        MGB_RECORD_EVENT_IF(
                profiling_device, RecordDeviceEvent, Timer::record_device(device));
    }
    // Apply op
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto i : cmd.outputs_descs) {
        output_descs.push_back(i);
    }
    // Here std::move is REQUIRED for removing duplicated references.
    auto outputs = apply_on_physical_tensor(
            apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs,
            cmd.validated);
    // After execute
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT_IF(
                profiling_device, RecordDeviceEvent, Timer::record_device(device));
        MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
    }
    // End profiling operator
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto output = cmd.outputs[i];
        if (mgb_unlikely(output == nullptr)) {
            // Output was deleted before execution; nothing to publish.
            MGB_RECORD_EVENT(OpOutputEvent, 0);
            MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
        } else if (mgb_unlikely(output->ptr != nullptr)) {
            // Output already materialized (e.g. by a concurrent recompute).
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
        } else {
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            produce_tensor(output, outputs[i]);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
            if (Profiler::is_profiling()) {
                sample_on_device(output->desc.comp_node, false);
            }
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        // NOTE(review): "compute time" is estimated from bytes moved
        // (input memory + output blob sizes), not from a measured duration.
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
        for (auto i : outputs) {
            estimate_compute_time += i->blob()->size();
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
        // Inputs were pinned in flush_apply_stack(); release them now.
        m_dtr.unpin(cmd.inputs, state);
    }
    MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
    // End profiling operator
}
875

876 877
// Drain m_apply_stack: execute each pending ApplyOp once all of its inputs
// are materialized.  Missing (dropped) inputs are regenerated first, which
// pushes their producer onto the same stack — so this loop implements the
// recursive DTR recompute without actual recursion.  m_applying guards
// against regenerate() re-entering this function.
void ChannelImpl::flush_apply_stack() {
    m_applying = true;
    auto& state = get_worker_state();
    while (!m_apply_stack.empty()) {
        auto& [cmd, idx, recomp, reason] =
                m_apply_stack.top();  // cmd.inputs[0~idx-1] is in memory
        if (idx == 0) {
            // First visit of this entry: pin inputs so eviction cannot drop
            // them while we regenerate the rest.
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
                MGB_RECORD_EVENT(
                        TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
            }
        }
        bool regen = false;
        for (size_t i = idx; i < cmd.inputs.size(); i++) {
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                // Resume from i+1 when we come back to this entry.
                idx = i + 1;
                regenerate(p);  // add ApplyOp to the stack
                regen = true;
                break;
            }
        }
        if (regen)
            continue;
        // the required input tensors are already in memory
        // Copy before pop: do_apply_op may push/pop on m_apply_stack, which
        // would invalidate the structured-binding references into top().
        auto [cmd_backup, recomp_backup, reason_backup] =
                std::make_tuple(cmd, recomp, reason);
        m_apply_stack.pop();
        do_apply_op(cmd_backup, reason_backup);
        if (recomp_backup) {
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, recomp_backup->id,
                    TensorCommandKind::ReGen);
            for (auto o : cmd_backup.outputs) {
                if (o) {
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
    }
    m_applying = false;
}

925
bool ChannelImpl::auto_evict(size_t force_num) {
926
    auto& state = get_worker_state();
927
    if (!m_dtr.comp_node.valid()) {
928
        return false;
929 930
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
931
    size_t flag = false;
M
Megvii Engine Team 已提交
932 933 934
    while ((state.options.dtr_eviction_threshold > 0 &&
            current_memory > state.options.dtr_eviction_threshold) ||
           force_num > 0) {
935
        MGB_RECORD_EVENT(AutoEvictEvent);
936
        sample_on_device(m_dtr.comp_node, false);
937
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
938
        if (!best) {
939
            MGB_RECORD_EVENT(AutoEvictFinishEvent);
940 941 942 943
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
944
            if (force_num > 0) {
M
Megvii Engine Team 已提交
945
                force_num--;
946 947
            }
            flag = true;
948 949 950 951
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
952
        }
953
        sample_on_device(m_dtr.comp_node, false);
954
        MGB_RECORD_EVENT(AutoEvictFinishEvent);
955
    }
956
    return flag;
957 958
}

959 960
// Detach every ComputePath that uses `dest` as an input.  Each such path's
// outputs are first regenerated (so they remain valid without the path) and
// then detached from their producer, dropping the path's ref counts on all
// its inputs until the path itself dies.
void ChannelImpl::detach_users(TensorInfo* dest) {
    // Copy the users list: detach_producer mutates dest->users while we
    // iterate.
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user : users) {
        // Copies for the same reason — the path is destroyed mid-loop.
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output : outputs) {
            // When a `ComputePath` is detach from it's input,
            // there is no need to reserve it,
            // so we detach all output of this path
            // to decrease it's `ref_cnt` to zero.
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
            for (auto* input : inputs) {
                input->ref_cnt--;
            }
        }
        // now user is dead
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}

983 984 985 986
// Report whether this channel is still usable, i.e. it has not been closed.
bool ChannelImpl::check_available() {
    if (m_closed) {
        return false;
    }
    return true;
}

987 988 989 990 991
// Block the calling (channel) thread until `info` satisfies `prop`:
// for HostValue the tensor's host value must be fetched, otherwise any
// materialized ptr suffices.  Returns the tensor's ptr.
// Only one waiter may be outstanding at a time (m_waitee).
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool require_host = prop == TensorProp::HostValue;
    auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
    bool wait_host = false;
    if (require_host && !host_available()) {
        // Enqueue a GetValue task so the worker fetches the host value.
        // avoid dead lock
        lock.unlock();
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), GetValue{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    GetValue{info},
            });
        }
        lock.lock();
        wait_host = true;
    }
    // The predicate also rethrows any exception recorded by the worker
    // (check_worker_exc_unsafe), so a failed task wakes us with an error.
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return require_host ? host_available() : static_cast<bool>(info->ptr);
    });
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
    m_waitee = nullptr;
    if (wait_host) {
        // Surface asynchronous device errors raised during the fetch.
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info->ptr;
}

// Wake the thread blocked in wait_tensor() if it is waiting on `info`.
// "_unsafe": callers in this file hold m_mutex before invoking this.
void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info != m_waitee) {
        return;
    }
    MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
    m_cv.notify_all();
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
M
Megvii Engine Team 已提交
1034
    for (auto* handle : m_valid_handle) {
1035 1036
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
1037
    }
1038
    return valid_tensors;
1039 1040
}

1041
// Allocate storage for blob `x`, evicting tensors (DTR) as needed on
// allocation failure, and falling back to a full defragmentation as a last
// resort.  Eviction is only attempted on the worker thread, since only the
// worker may run auto_evict.  Logging is silenced during the retry loop to
// avoid spamming allocation-failure warnings.
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
    bool in_worker = (get_worker_tid() == std::this_thread::get_id());
    // Pre-evict until a block of the requested size could fit.
    auto reserve_size = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            bool evict_suc = auto_evict(1);
            if (!evict_suc)
                return false;
        }
        return true;
    };
    auto pre_level = set_log_level(LogLevel::NO_LOG);
    if (in_worker) {
        reserve_size(x->size());
    }
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
    MGB_CATCH(MemAllocError&, {
        bool suc = false;
        if (in_worker) {
            // Evict one tensor at a time and retry until alloc succeeds or
            // there is nothing left to evict.
            while (!suc) {
                if (!auto_evict(1)) {
                    break;
                }
                MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
                MGB_CATCH(MemAllocError&, { continue; });
                suc = true;
            }
        }
        if (!suc) {
            // Last resort: defragment and retry once; this alloc is allowed
            // to throw.  Restore the log level briefly so the warning shows.
            set_log_level(pre_level);
            mgb_log_warn(
                    "reallocating all cuda memory to alleviate fragmentation, the "
                    "performance may be affected");
            set_log_level(LogLevel::NO_LOG);
            imperative_log_profile_begin("defrag");
            BlobManager::inst()->defrag(x->comp_node());
            imperative_log_profile_end("defrag");
            BlobManager::inst()->alloc_direct(x, x->size());
        }
    });
    set_log_level(pre_level);
}

1086
// Worker-thread dispatcher: executes one queued Command.  The actual work
// is in `cmd_visitor`, selected per command type via `if constexpr`; the
// outer std::visit wraps it with optional exception capture
// (options.catch_worker_execption) that marks affected tensors invalid,
// stores the exception for the channel thread, and wakes any waiter.
void ChannelImpl::process_one_task(Command& icmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    // TODO: remove std::visit for support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, Put>) {
            // Materialize a host value into a device tensor.
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
                                      : Tensor::make(cmd.value);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            produce_tensor(cmd.dest, std::move(value));
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, ApplyOp>) {
            // If any input is already invalid, poison all outputs and bail.
            for (auto& i : cmd.inputs) {
                if (mgb_unlikely(i->invalid)) {
                    MGB_LOCK_GUARD(m_mutex);
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
                    }
                    return;
                }
            }
            if (state.options.enable_dtr_auto_drop) {
                // Run through the apply stack so dropped inputs get
                // regenerated, then attach fresh DSU nodes to the outputs.
                m_apply_stack.push({cmd, 0, nullptr, "cmd"});
                flush_apply_stack();
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
                    output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                }
            } else {
                do_apply_op(cmd, "cmd");
            }
            if (state.options.enable_drop && state.options.record_computing_path) {
                // Record a ComputePath so outputs can be recomputed later —
                // but only for ops that are safe to replay (not inplace, not
                // cross-device, not blacklisted by name).
                auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                    auto& input = std::get<0>(tuple2);
                    auto& output = std::get<1>(tuple2);
                    if (!input->ptr || !output->ptr) {
                        return false;
                    }
                    return input->ptr->blob()->storage() ==
                           output->ptr->blob()->storage();
                };
                // FIXME: do not use opname as identifier
                auto get_name = [](const OpDef& opdef) {
                    if (auto attr = opdef.try_cast_final<OprAttr>()) {
                        return attr->type.c_str();
                    }
                    return opdef.dyn_typeinfo()->name;
                };

                auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
                    return info->desc.comp_node != comp_node;
                };

                bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                bool inplace =
                        any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                    TensorInfo::ComputePath::make(
                            cmd.id, cmd.op, cmd.inputs, cmd.outputs, cmd.outputs_descs);
                    size_t detach_cnt = 0;
                    // BatchNorm's running statistics are updated inplace and
                    // must not be recomputed; detach them from the path.
                    if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
                        cmd.outputs.size() == 6) {
                        cmd.outputs[0]->detach_producer();  // detach running_mean
                        cmd.outputs[1]->detach_producer();  // detach running_var
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= 2;
                        }
                    }
                    // Small outputs are cheaper to keep than to recompute.
                    for (auto output : cmd.outputs) {
                        if (output->producer &&
                            !output->size_exceeds_thd(
                                    state.options.dtr_evictee_minimum_size)) {
                            output->detach_producer();
                            detach_cnt++;
                        }
                    }
                    for (auto input : cmd.inputs) {
                        input->ref_cnt -= detach_cnt;
                    }
                }
            }
        } else if constexpr (std::is_same_v<T, Del>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
            // Save id/device before free() destroys the TensorInfo.
            CompNode device = cmd.dest->desc.comp_node;
            uint64_t tensor_id = cmd.dest->id;
            free(cmd.dest);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
            sample_on_device(device, false);
        } else if constexpr (std::is_same_v<T, GetValue>) {
            if (cmd.dest->invalid)
                return;
            imperative_log_profile_begin("GetValue");
            if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                regenerate(cmd.dest);
            }
            cmd.dest->ptr->fetch_value();
            MGB_LOCK_GUARD(m_mutex);
            // Wake the channel thread blocked in wait_tensor().
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
            do_drop(cmd.dest, true);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
        } else if constexpr (std::is_same_v<T, SetOption>) {
            options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            MGB_RECORD_EVENT(StartProfileEvent);
            CompNode::sync_all();
            // Declare all pre-existing tensors so the profile is complete.
            for (auto* info : cmd.capture_tensors) {
                MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                if (info->status == TensorInfo::Produced) {
                    // TODO: handle drop
                    MGB_RECORD_EVENT(
                            TensorProduceEvent, info->id, info->desc.layout,
                            info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                }
            }
            CompNode::foreach ([&](CompNode device) {
                sample_on_device(device, true);
                MGB_RECORD_EVENT_IF(
                        (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                        Timer::record_device(device));
            });
            MGB_RECORD_EVENT(StartProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
            MGB_RECORD_EVENT(StopProfileEvent);
            // Mark surviving tensors as released/erased so the profile closes
            // every tensor lifetime it opened.
            for (auto* info : cmd.escape_tensors) {
                bool has_value = info->status == TensorInfo::Produced;
                if (has_value) {
                    MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
                }
                MGB_RECORD_EVENT(TensorEraseEvent, info->id);
            }
            CompNode::foreach (
                    [&](CompNode device) { sample_on_device(device, true); });
            MGB_RECORD_EVENT(StopProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, PushScope>) {
            MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
            MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
        } else {
            // Compile-time exhaustiveness check for the command variant.
            static_assert(!std::is_same_v<T, T>);
        }
    };
    std::visit(
            [&](const auto& cmd) {
                using T = std::decay_t<decltype(cmd)>;
                if (!options.catch_worker_execption) {
                    cmd_visitor(cmd);
                    return;
                }
                try {
                    cmd_visitor(cmd);
                } catch (...) {
                    MGB_LOCK_GUARD(m_mutex);
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        for (auto oup : cmd.outputs) {
                            oup->invalid = true;
                        }
                    } else if constexpr (std::is_same_v<T, Put>) {
                        cmd.dest->invalid = true;
                    }
                    // Store for rethrow on the channel thread and wake any
                    // blocked waiter so it observes the failure.
                    m_worker_exc = std::current_exception();
                    MGB_RECORD_EVENT(WorkerExceptionEvent);
                    if (m_waitee) {
                        notify_tensor_unsafe(m_waitee);
                    }
                }
            },
            icmd.data);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
1279 1280
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
1281 1282
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
1283 1284 1285 1286 1287
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
1288 1289
    }
}
1290

1291
// Begin profiling: enqueue a StartProfile task carrying a snapshot of all
// currently valid tensors so the worker can declare them in the profile.
// Called from the channel thread.
void ChannelImpl::start_profile() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
        // The profiling variant additionally records the python call stack.
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StartProfile{std::move(capture_tensors)},
            });
        }
    }
}

1309
// End profiling: enqueue a StopProfile task with the set of tensors that
// outlive the profiling session, so their lifetimes can be closed in the
// recorded profile.  Called from the channel thread.
void ChannelImpl::stop_profile() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
        // The profiling variant additionally records the python call stack.
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StopProfile{std::move(escape_tensors)},
            });
        }
    }
}

// Enter a named profiling scope: recorded immediately on the channel thread
// and mirrored to the worker via a PushScope task so both timelines carry
// the scope.
void ChannelImpl::push_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.stack_manager.enter(name);
    MGB_RECORD_EVENT(ScopeEvent, name);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PushScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PushScope{name},
        });
    }
}

// Leave a named profiling scope; counterpart of push_scope.  The name must
// match the scope currently on top of the stack manager.
void ChannelImpl::pop_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.stack_manager.exit(name);
    MGB_RECORD_EVENT(ScopeFinishEvent, name);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PopScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PopScope{name},
        });
    }
}

1363
void ChannelImpl::assert_in_channel() {
M
Megvii Engine Team 已提交
1364 1365 1366
    mgb_assert(
            get_worker_tid() != std::this_thread::get_id(),
            "this method cannot be called in worker thread");
1367 1368 1369
}

void ChannelImpl::assert_in_worker() {
M
Megvii Engine Team 已提交
1370 1371 1372
    mgb_assert(
            get_worker_tid() == std::this_thread::get_id(),
            "this method can only be called in worker thread");
1373 1374
}

1375
// Record a memory-usage sample for `device`.  With force == false the call
// is rate-limited per thread: only every `sample_rate`-th invocation emits
// a sample, and a rate of 0 disables sampling entirely.
void ChannelImpl::sample_on_device(CompNode device, bool force) {
    if (!Profiler::is_profiling()) {
        return;
    }
    if (!force) {
        thread_local int sample_counter = 0;
        int rate = Profiler::get_option("sample_rate", 0);
        if (!rate) {
            // Note: counter is intentionally not advanced when disabled,
            // matching the original short-circuit behavior.
            return;
        }
        if ((++sample_counter) % rate != 0) {
            return;
        }
    }
    MGB_RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}

1391 1392 1393
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
1394
        erase_candidate(i);
1395 1396 1397
    }
}

1398 1399
void ChannelImpl::DynamicSublinear::unpin(
        const SmallVector<TensorInfo*>& vec, WorkerState& state) {
1400 1401
    for (auto i : vec) {
        i->unpin();
1402 1403 1404 1405 1406
        if (i->pinned == 0 &&
            i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
            i->cand_index == UINT_MAX) {
            insert_candidate(i);
        }
1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452
    }
}

// After `ptr` is recomputed it leaves its evicted-neighborhood component:
// subtract its compute time from the component root's aggregate cost, then
// split it out as a fresh singleton DSU node carrying only its own cost.
void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    dsu_fa->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

// After `ptr` is evicted, merge it with every already-dropped neighbor
// (producer inputs and outputs) in the DSU, so the aggregated cost of
// recomputing the whole evicted neighborhood is tracked at one root.
// Note the merge direction differs for inputs vs outputs.
void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    // Estimate the cost of re-materializing this tensor's dropped neighborhood:
    // every DROP-able input/output of its producer contributes the pooled time
    // of its DSU component, floored by its own compute time.
    double total = 0;
    auto add_component_cost = [&](TensorInfo* nbr) {
        double t = find_father(nbr->dsu_ptr)->t;
        total += (t < nbr->compute_time) ? nbr->compute_time : t;
    };
    for (auto* in : ptr->producer->inputs) {
        if (in->evict_type == EvictType::DROP) {
            add_component_cost(in);
        }
    }
    for (auto* out : ptr->producer->outputs) {
        if (out && out->evict_type == EvictType::DROP) {
            add_component_cost(out);
        }
    }
    return total;
}

M
Megvii Engine Team 已提交
1453 1454
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
        bool enable_dtr_sqrt_sampling = false) {
1455 1456 1457
    if (candidates.empty())
        return nullptr;

1458 1459
    double min_msps = -1;
    TensorInfo* best = nullptr;
1460 1461
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
M
Megvii Engine Team 已提交
1462 1463
        while (sz * sz <= candidates.size())
            sz++;
1464
        sz--;
1465 1466 1467
    } else {
        sz = candidates.size();
    }
1468 1469 1470 1471 1472 1473 1474

    size_t ti = rand() % sz;
    for (size_t vi = 0; vi < sz; vi++) {
        if (!enable_dtr_sqrt_sampling) {
            ti = vi;
        }
        auto i = candidates[ti];
1475
        if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
1476
            double neighbor_cost = estimate_neighbor_cost(i);
M
Megvii Engine Team 已提交
1477 1478 1479 1480
            size_t begin_ptr =
                    reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(
                    begin_ptr, begin_ptr + i->ptr->blob()->size());
1481
            double free_mem = side_info.first + side_info.second;
M
Megvii Engine Team 已提交
1482 1483
            double msps = i->eval_func(
                    neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
1484 1485 1486 1487 1488
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
1489 1490 1491 1492 1493
        if (enable_dtr_sqrt_sampling) {
            ti += rand() % sz;
            if (ti > candidates.size())
                break;
        }
1494 1495 1496 1497
    }
    return best;
}

M
Megvii Engine Team 已提交
1498 1499
void ChannelImpl::DynamicSublinear::merge(
        std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
1500 1501 1502 1503 1504 1505 1506 1507 1508
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

M
Megvii Engine Team 已提交
1509 1510
std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
        std::shared_ptr<DsuNode>& x) {
1511 1512 1513 1514 1515 1516 1517 1518 1519
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    // Lazily record a comp node from the first candidate ever registered.
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
    // A tensor may only be registered once: a live candidate always carries a
    // valid slot index, while a non-candidate carries UINT_MAX.
    mgb_assert(
            ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
            ptr->cand_index);
    ptr->cand_index = candidates.size();
    candidates.push_back(ptr);
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    // Disabling DTR clears the whole candidate list at once; in that case
    // there is nothing left to remove, only the stale index to reset.
    if (candidates.empty()) {
        ptr->cand_index = UINT_MAX;
        return;
    }
    // Skip tensors that were already removed (index reset to UINT_MAX).
    if (ptr->cand_index == UINT_MAX) {
        return;
    }
    // O(1) removal: move the last entry into this tensor's slot, fix the moved
    // entry's back-pointer, then shrink the vector.
    size_t slot = ptr->cand_index;
    std::swap(candidates[slot], candidates.back());
    candidates[slot]->cand_index = slot;
    candidates.pop_back();
    ptr->cand_index = UINT_MAX;
}

// Stamp the tensor with the channel's current estimated timestamp, marking it
// as just-touched.
void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}