interpreter_impl.cpp 54.1 KB
Newer Older
M
Megvii Engine Team 已提交
1
/**
2
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
M
Megvii Engine Team 已提交
3 4
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
M
Megvii Engine Team 已提交
6 7 8 9 10 11
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12
#include "./interpreter_impl.h"
13

14 15
#include "range/v3/all.hpp"

16
#include "megbrain/common.h"
17 18
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
19 20
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
21
#include "megbrain/imperative/ops/utility.h"
22
#include "megbrain/imperative/utils/stats.h"
23 24
#include "megbrain/imperative/utils/to_string.h"

25
#include "../blob_manager_impl.h"
26 27 28
#include "../event_pool.h"
#include "../op_trait.h"

29 30 31 32 33
using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

34
namespace {
M
Megvii Engine Team 已提交
35 36 37 38 39 40 41 42
auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
    SmallVector<uint64_t> tid;
    for (auto* ptinfo : tinfo) {
        tid.push_back(ptinfo->id);
    }
    return tid;
};
}  // namespace
43

44
namespace mgb {
M
Megvii Engine Team 已提交
45
using namespace profiler;
46 47
}

48 49 50 51 52
#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif
53 54 55 56 57 58 59

namespace mgb {

/**
 * USAGE
 *
 *   header:
60
 *     namespace mgb { void imperative_log_profile(const char* message); }
61 62 63 64 65
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
66
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    // Open a custom profiler interval named `message`.
    MGB_RECORD_EVENT(CustomEvent, std::string{message});
}

71
SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    // Close the custom profiler interval previously opened with the same name.
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
}

76
SYMBOL_EXPORT
void imperative_log_profile(const char* message) {
    // Record a zero-length interval: begin immediately followed by end.
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

82 83 84 85
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message, const char* device) {
    // Device-scoped variant: additionally records a device event so the
    // interval can be correlated with the device timeline.
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message, const char* device) {
    // Mirror of the device-scoped begin: record the device event first, then
    // close the custom interval.
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
}

M
Megvii Engine Team 已提交
98
}  // namespace mgb
99

100 101 102 103 104 105 106 107 108 109 110 111 112 113
// Thread id of the worker thread (assigned when the worker thread starts).
std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

// Access channel-thread state; assert_in_channel() checks the caller is on
// the channel (user-facing) thread before exposing a mutable reference.
ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

// Access worker-thread state; assert_in_worker() checks the caller is on the
// worker thread.
ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

114 115 116
void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
117
    auto custom_allocator = [&](CompNode device, size_t size) {
118 119 120
        auto blob = Blob::make(device, size);
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
121 122 123
    };
    OpDef::set_allocator(custom_allocator);
    BlobManager::inst()->set_allocator(custom_allocator);
124 125
}

126
// Do not use m_xxx_state directly
127 128 129
#define m_channel_state
#define m_worker_state

130 131 132 133 134 135 136 137 138
// Factory: each channel is an independent ChannelImpl.
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

// Process-wide interpreter singleton (function-local static: initialized on
// first use, thread-safe since C++11).
Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

139
// Upload a host tensor into the channel and return an opaque handle to it.
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    auto info = put_impl(value, no_cache);
    return reinterpret_cast<Handle>(info);
}

// Core of put(HostTensorND): allocates a TensorInfo, enqueues the async Put
// command on the worker, and (at async level 0) synchronizes immediately.
TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    if (value.empty()) {
        // Normalize empty tensors to a contiguous layout in place.
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto info = alloc();
    // Small tensors keep a host-side copy so their value is available for
    // host compute / static inference without a device round-trip.
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {value.layout(), value.comp_node()});
    if (value.layout().total_nr_elems() <= size_threshold) {
        info->h_value = value;
        info->desc.value = value.proxy_to_default_cpu();
    }
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Put{info, value, no_cache},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Put{info, value, no_cache},
        });
    }
    if (m_async_level == 0) {
        // Fully synchronous mode: wait for the worker and surface any
        // asynchronous device error right away.
        sync_impl();
        info->desc.comp_node.sync();
        auto err = info->desc.comp_node.check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info;
}

180
// Wrap an existing device tensor (with optional host mirror) in a handle.
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return reinterpret_cast<Handle>(put_impl(data, hvalue));
}
M
Megvii Engine Team 已提交
185 186
// Core of put(DeviceTensorND): the tensor already lives on the device, so it
// is produced immediately — no worker round-trip needed.
TensorInfo* ChannelImpl::put_impl(
        const DeviceTensorND& data, const HostTensorND& hvalue) {
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    auto info = alloc();
    MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
    // Small tensors keep a host-side value for host compute / static infer.
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {data.layout(), data.comp_node()});
    if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
        info->desc.value = hvalue.proxy_to_default_cpu();
    }
    info->ptr = Tensor::make(data, hvalue);
    MGB_RECORD_EVENT(
            TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
            data.raw_ptr());
    info->status = TensorInfo::Produced;
    MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
    return info;
}

205
// Delete a handle. Silently ignored after close() (close() already deletes
// all remaining handles).
void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    del_impl(handle);
}

// Invalidate the handle on the channel thread and enqueue the actual
// resource release (Del) on the worker.
void ChannelImpl::del_impl(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Del{info},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Del{info},
        });
    }
}

229
// Request eviction (drop) of a tensor's device value; a no-op unless the
// enable_drop option is set. The tensor can later be regenerated on demand.
void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(
                m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), Drop{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    Drop{info},
            });
        }
    }
}

251
void ChannelImpl::dispatch_default_cpu(
M
Megvii Engine Team 已提交
252
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
253 254
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
255
    auto& state = get_channel_state();
256 257

    auto name = op->trait()->make_name(*op);
258
    auto _ = StackManager::Guard(name, &state.stack_manager);
259

M
Megvii Engine Team 已提交
260 261
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
262
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
263

264 265
    SmallVector<DeviceTensorND> input_tensornds;
    CompNode output_cn;
266 267
    {
        MGB_LOCK_GUARD(m_mutex);
268
        for (auto&& info : input_infos) {
269
            auto input_cn = info->desc.comp_node;
270
            if (!output_cn.valid()) {
271 272 273 274 275 276
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
M
Megvii Engine Team 已提交
277 278
                input_tensornds.emplace_back(
                        info->ptr->get_value().proxy_to_default_cpu());
279
            } else {
280
                // We assign h_value before drop ptr
281 282
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
283 284 285 286 287 288 289 290 291
            }
        }
    }

    SmallVector<DeviceTensorND> output_tensornds;
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which need alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
M
Megvii Engine Team 已提交
292 293
        output_tensornds.emplace_back(
                HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
294 295
    }

296
    uint64_t op_id = Profiler::next_id();
297

298 299 300 301 302 303 304 305 306
    if (op->trait()->apply_on_device_tensornd) {
        OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
    } else {
        // proxy to apply_on_physical_tensor
        SmallVector<TensorPtr> input_tensors;
        for (auto&& input_tensornd : input_tensornds) {
            input_tensors.push_back(Tensor::make(
                    input_tensornd, HostTensorND::make_proxy(input_tensornd)));
        }
307 308
        auto output_tensors = OpDef::apply_on_physical_tensor(
                *op, input_tensors, output_descs, validated);
309 310 311 312
        for (size_t i = 0; i < output_tensors.size(); ++i) {
            output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
        }
    }
313 314 315

    SmallVector<TensorInfo*> output_infos;
    for (auto&& tensornd : output_tensornds) {
M
Megvii Engine Team 已提交
316 317
        HostTensorND host_tensornd =
                HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
318
        // use `put` for consistency
319
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
320
        mgb_assert(info->desc.layout.ndim != 0);
321
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
322
        outputs->push_back(reinterpret_cast<Handle>(info));
323
    }
M
Megvii Engine Team 已提交
324
    auto op_info_getter = [op] {
325 326
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
M
Megvii Engine Team 已提交
327
        for (auto&& [key, value] : props) {
328 329 330 331
            op_info[key] = value;
        }
        return op_info;
    };
M
Megvii Engine Team 已提交
332 333 334
    MGB_RECORD_EVENT(
            OpDispatchEvent, op_id, name, op_info_getter, tinfo_to_tid(input_infos),
            tinfo_to_tid(output_infos), state.stack_manager.dump());
335
}
336

337
void ChannelImpl::dispatch_kernel(
M
Megvii Engine Team 已提交
338
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
339 340
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
341
    auto& state = get_channel_state();
342 343
    auto& options = state.options;

344 345 346
    auto name = op->trait()->make_name(*op);
    auto _ = StackManager::Guard{name, &state.stack_manager};

M
Megvii Engine Team 已提交
347 348
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
349
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
350

351 352 353 354
    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());

    outputs->reserve(output_descs.size());
355 356
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
357
        auto info = alloc();
358
        init(info, desc);
359 360 361
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
M
Megvii Engine Team 已提交
362
                                    .proxy_to_comp_node(desc.comp_node);
363
        }
364
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
365
        outputs->push_back(reinterpret_cast<Handle>(info));
366
    }
367 368 369
    ApplyOp cmd{Profiler::next_id(),     std::move(op),
                std::move(input_infos),  std::move(output_infos),
                std::move(output_descs), validated};
370
    if (Profiler::is_profiling()) {
371 372 373 374 375 376 377 378
        auto op_info_getter = [op = cmd.op] {
            std::unordered_map<std::string, std::string> op_info;
            auto props = OpDef::props(*op);
            for (auto&& [key, value] : props) {
                op_info[key] = value;
            }
            return op_info;
        };
379
        MGB_RECORD_EVENT(
380 381
                OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
                tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
382
        m_worker.add_task(
383
                {Profiler::next_id(), std::move(cmd),
384 385 386 387
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
388
                std::move(cmd),
389 390
        });
    }
391
    if (!validated && options.async_level == 1) {
392
        sync_impl();
393
    } else if (options.async_level == 0) {
394
        sync_impl();
395
        // check device error
396
        for (auto&& oup : *outputs) {
397 398
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
399 400
            auto err = info->ptr->comp_node().check_async_error();
            mgb_assert(!err, "%s", err->what());
401
        }
402
    }
403 404 405
}

// Public op entry point. Fast path: a full GetVarShape on a tensor whose
// shape is already known is answered on the host without dispatching.
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    // Guard against zero-input ops: the original unconditionally read
    // inputs[0], an out-of-bounds access when `inputs` is empty.
    if (!inputs.empty()) {
        auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
        if (op->same_type<GetVarShape>() && input->desc.layout.ndim) {
            size_t ndim = input->desc.layout.ndim;
            auto& gvs = op->cast_final_safe<GetVarShape>();
            if (gvs.axis == MEGDNN_MAX_NDIM) {
                HostTensorND shape_tensor{input->desc.comp_node, {ndim}, dtype::Int32()};
                DeviceTensorND shape_tensor_device = shape_tensor.proxy_to_default_cpu();
                cg::copy_shape_to_tensor_value(shape_tensor_device, input->desc.layout);
                return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
            }
        }
    }
    return apply_op_impl(std::move(op), inputs);
}

// Validate inputs, snapshot their descriptors under m_info_spin, then route
// to the host (DEFAULT_CPU) or async worker (KERNEL) dispatch path.
SmallVector<Handle> ChannelImpl::apply_op_impl(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(
                m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
                i);
    }
    SmallVector<TensorInfo*> input_infos;
    SmallVector<LogicalTensorDesc> input_descs;
    {
        MGB_LOCK_GUARD(m_info_spin);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(
                    !info->invalid,
                    "an input tensor is unusable due to previous error");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    DispatchMode dispatch_mode = state.options.enable_host_compute
                                       ? OpDef::decide_dispatch_mode(*op, input_descs)
                                       : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

462
// Block until the tensor's host value is available and return it.
HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

474
// Return the tensor's shape; answered from the cached descriptor when known,
// otherwise blocks on the worker until the shape is produced.
TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

489
// Return the tensor's dtype from the cached descriptor (never blocks).
DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

502
// Return the tensor's comp node from the cached descriptor (never blocks).
CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

515
// Block until the tensor's device value is available and return it.
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

void ChannelImpl::sync() {
526
    MGB_LOCK_GUARD(m_spin);
527
    mgb_assert(check_available(), "Channel already closed");
528 529 530 531
    sync_impl();
}

void ChannelImpl::sync_impl() {
532 533 534 535 536 537
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
538
    MGB_LOCK_GUARD(m_spin);
539 540 541 542
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
M
Megvii Engine Team 已提交
543
    for (auto* handle : valid_handles) {
544
        del_impl(handle);
545 546 547
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
548
    sync_impl();
549
    m_closed = true;
550 551
}

552
size_t ChannelImpl::get_option(std::string name) {
553
    MGB_LOCK_GUARD(m_spin);
554
    mgb_assert(check_available(), "Channel already closed");
555 556
    auto& state = get_channel_state();
    return state.options.get_option(name);
557 558
}

559
void ChannelImpl::set_option(std::string name, size_t value) {
560
    MGB_LOCK_GUARD(m_spin);
561
    mgb_assert(check_available(), "Channel already closed");
562 563
    auto& state = get_channel_state();
    state.options.set_option(name, value);
564 565 566 567 568 569 570 571 572 573
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), SetOption{name, value},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                SetOption{name, value},
        });
    }
574 575
}

576 577 578 579 580 581
// Clear the DTR eviction-candidate set.
void ChannelImpl::clear_candidates() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_dtr.candidates.clear();
}

582
// Allocate a fresh TensorInfo from the pool and assign it a unique id; when
// profiling, also derive a human-readable name from the call stack.
TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this] {
        // Hold the pool spinlock only for the pool operation itself.
        MGB_LOCK_GUARD(m_pool_spin);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        size_t tensor_id = state.stack_manager.current()->next_id("tensor");
        info->name =
                state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
    }
    return info;
}

597
// Register a freshly allocated TensorInfo as a valid handle and seed its
// descriptor.
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
    m_valid_handle.insert(reinterpret_cast<Handle>(info));
    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    // `desc` is taken by value — move it into place instead of copying.
    info->desc = std::move(desc);
}

M
Megvii Engine Team 已提交
604
// Drop a tensor's device value so it can be regenerated later from its
// producer. `user` marks an explicit user request (warn if not droppable).
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
    if (!ptr->producer) {
        // Without a recorded producer the value cannot be regenerated.
        if (user) {
            mgb_log_warn(
                    "the input that produced tensor %p has been deleted, this drop "
                    "operation will be ignored",
                    ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

622
// Free a TensorInfo. Under DTR auto-drop, prefer eviction over immediate
// destruction so referenced tensors don't pin large amounts of memory.
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

// Free `ptr` and cascade to producer inputs whose ref count drops to zero
// and that are flagged deletable.
void ChannelImpl::recursive_free(TensorInfo* ptr) {
    MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
    // Capture the id now: real_free() returns `ptr` to the pool, so reading
    // ptr->id afterwards (as the original finish event did) is a
    // use-after-free.
    auto id = ptr->id;
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    MGB_RECORD_EVENT(TensorCommandFinishEvent, id, TensorCommandKind::RecFree);
}

// Actually destroy a TensorInfo: detach it from the dependency graph,
// emit release/erase events, and return it to the pool.
void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    MGB_LOCK_GUARD(m_pool_spin);
    m_pool.free(ptr);
}

675
// The worker queue is bound to this channel for its whole lifetime.
ChannelImpl::ChannelImpl() : m_worker(this) {}

// Destruction closes the channel (idempotent; see close()).
ChannelImpl::~ChannelImpl() {
    close();
}
680

681
// Worker-side: attach a computed TensorPtr to its TensorInfo, refresh the
// descriptor, and wake any thread blocked in wait_tensor.
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
            ptr->dev_tensor(false).raw_ptr());
    // update tensor desc for static infer
    if (dest->desc.layout.ndim) {
        mgb_assert(
                dest->desc.layout.eq_shape(ptr->layout()),
                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
                ptr->layout().to_string().c_str());
    }
    // in order to avoid performance impact,
    // memory forwarding is disabled when DTR is enabled
    if (state.options.enable_dtr_auto_drop) {
        ptr->to_contiguous_inplace();
    }
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    // Unpinned tensors above the DTR size threshold become eviction candidates.
    if (dest->pinned == 0 &&
        dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

713
// Drop the device value of a tensor (keeping the TensorInfo itself) and
// remove it from the DTR candidate set if it was large enough to be there.
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
}

723
// Recompute a dropped tensor by replaying its producer op. Pushes the replay
// onto the apply stack; flushes immediately unless we're already inside a
// flush (re-entrancy guard via m_applying).
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        auto&& path = dest->producer;
        m_apply_stack.push(
                {ApplyOp{path->id, path->op, path->inputs, path->outputs,
                         path->outputs_descs},
                 0, dest, "dtr"});
        if (!m_applying)
            flush_apply_stack();
    }
}

735
void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
736 737
    using namespace ranges;
    using namespace ranges::views;
738
    auto& state = get_worker_state();
M
Megvii Engine Team 已提交
739 740
    bool profiling_device =
            Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
741
    uint64_t apply_id = cmd.id;
742
    SmallVector<TensorPtr> inputs;
743
    inputs.reserve(cmd.inputs.size());
744 745 746
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
747
        // refcnt ++, owners: [i->ptr, tensor_inputs]
748
        // tensor_inputs.push_back(i->ptr);
749
        inputs.push_back(i->ptr);
750
    }
M
Megvii Engine Team 已提交
751 752
    if (state.options.enable_dtr_auto_drop &&
        state.options.dtr_eviction_threshold > 0) {
753 754
        auto_evict(0);
    }
M
Megvii Engine Team 已提交
755
    auto apply_on_physical_tensor =
756
            [&](auto&& self, const OpDef& def, SmallVector<TensorPtr>&& inputs,
757 758
                SmallVector<LogicalTensorDesc>& output_descs,
                const bool& validated) -> SmallVector<TensorPtr> {
759
        if (def.trait()->make_forward_graph) {
760 761 762 763 764 765 766 767 768 769
            auto apply_functor = [&](std::shared_ptr<OpDef> op,
                                     SmallVector<TensorPtr> inputs,
                                     size_t nr_outputs) -> SmallVector<TensorPtr> {
                auto opname = op->trait()->make_name(*op);
                imperative_log_profile_begin(opname.c_str());
                auto outputs = self(self, *op, std::move(inputs), output_descs, false);
                imperative_log_profile_end(opname.c_str());
                return outputs;
            };
            auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
770 771
            // apply recursivily
            SmallVector<LogicalTensorDesc> input_descs;
M
Megvii Engine Team 已提交
772
            for (auto&& input : inputs) {
773
                input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
774
            }
775
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
776 777
            auto outputs = forward_graph.apply<TensorPtr>(
                    inputs, apply_functor, const_functor);
778 779
            return outputs;
        }
780 781 782 783 784 785 786 787 788 789 790 791
        // Check Input Layout
        // Get the input layout constraints, and if the constraint is not satisfied
        // inplace update the layout and blob to make the tensor contiguous
        auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
        for (size_t idx = 0; idx < inputs.size(); ++idx) {
            auto&& layout_checker = constraints[idx];
            if (layout_checker) {
                inputs[idx]->to_contiguous_inplace(layout_checker);
            }
        }
        return OpDef::apply_on_physical_tensor(
                def, std::move(inputs), output_descs, validated);
792
    };
793
    MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
794 795 796 797
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
798 799 800
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
801
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
802 803 804
            }
        }
    }
M
Megvii Engine Team 已提交
805
    for (auto* input : cmd.inputs) {
806
        auto input_id = input->id;
807 808 809
        MGB_RECORD_EVENT(OpInputEvent, input_id);
        MGB_RECORD_EVENT(TensorUsageEvent, input_id);
        MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
810 811
    }
    // Before wait
M
Megvii Engine Team 已提交
812
    // TODO: split operator wait and execute so that OpWait could be corrected recorded.
813
    // Before execute
M
Megvii Engine Team 已提交
814
    for (auto&& [device, kernel_id] : kernels) {
815
        MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
M
Megvii Engine Team 已提交
816
        MGB_RECORD_EVENT_IF(
817 818
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
819 820
    }
    // Apply op
821
    SmallVector<LogicalTensorDesc> output_descs;
822 823
    for (auto i : cmd.outputs_descs) {
        output_descs.push_back(i);
824
    }
825
    // Here std::move is REQUIRED for removing duplicated references.
826
    auto outputs = apply_on_physical_tensor(
827 828
            apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs,
            cmd.validated);
829
    // After execute
M
Megvii Engine Team 已提交
830 831
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT_IF(
832 833
                (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                Timer::record_device(device));
834
        MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
835 836
    }
    // End profiling operator
837 838
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
839
        auto output = cmd.outputs[i];
840
        if (mgb_unlikely(output == nullptr)) {
841 842
            MGB_RECORD_EVENT(OpOutputEvent, 0);
            MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
843
        } else if (mgb_unlikely(output->ptr != nullptr)) {
844 845
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
846
        } else {
847
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
848
            produce_tensor(output, outputs[i]);
849
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
850
            sample_on_device(output->desc.comp_node, false);
851 852 853 854 855 856 857 858
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
859
        for (auto i : outputs) {
860
            estimate_compute_time += i->blob()->size();
861 862 863 864 865 866 867
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
868
        m_dtr.unpin(cmd.inputs, state);
869
    }
870
    MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
871
    // End profiling operator
872
}
873

874 875
void ChannelImpl::flush_apply_stack() {
    m_applying = true;
876
    auto& state = get_worker_state();
877
    while (!m_apply_stack.empty()) {
M
Megvii Engine Team 已提交
878 879
        auto& [cmd, idx, recomp, reason] =
                m_apply_stack.top();  // cmd.inputs[0~idx-1] is in memory
880 881 882 883 884
        if (idx == 0) {
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
M
Megvii Engine Team 已提交
885 886
                MGB_RECORD_EVENT(
                        TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
887 888 889
            }
        }
        bool regen = false;
M
Megvii Engine Team 已提交
890
        for (size_t i = idx; i < cmd.inputs.size(); i++) {
891 892 893 894 895 896
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                idx = i + 1;
M
Megvii Engine Team 已提交
897
                regenerate(p);  // add ApplyOp to the stack
898 899 900 901
                regen = true;
                break;
            }
        }
M
Megvii Engine Team 已提交
902 903
        if (regen)
            continue;
904
        // the required input tensors are already in memory
M
Megvii Engine Team 已提交
905 906
        auto [cmd_backup, recomp_backup, reason_backup] =
                std::make_tuple(cmd, recomp, reason);
907
        m_apply_stack.pop();
908
        do_apply_op(cmd_backup, reason_backup);
909
        if (recomp_backup) {
M
Megvii Engine Team 已提交
910 911 912
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, recomp_backup->id,
                    TensorCommandKind::ReGen);
913 914
            for (auto o : cmd_backup.outputs) {
                if (o) {
915 916 917 918
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
919
    }
920
    m_applying = false;
921 922
}

923
bool ChannelImpl::auto_evict(size_t force_num) {
924
    auto& state = get_worker_state();
925
    if (!m_dtr.comp_node.valid()) {
926
        return false;
927 928
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
929
    size_t flag = false;
M
Megvii Engine Team 已提交
930 931 932
    while ((state.options.dtr_eviction_threshold > 0 &&
            current_memory > state.options.dtr_eviction_threshold) ||
           force_num > 0) {
933
        MGB_RECORD_EVENT(AutoEvictEvent);
934
        sample_on_device(m_dtr.comp_node, false);
935
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
936
        if (!best) {
937
            MGB_RECORD_EVENT(AutoEvictFinishEvent);
938 939 940 941
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
942
            if (force_num > 0) {
M
Megvii Engine Team 已提交
943
                force_num--;
944 945
            }
            flag = true;
946 947 948 949
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
950
        }
951
        sample_on_device(m_dtr.comp_node, false);
952
        MGB_RECORD_EVENT(AutoEvictFinishEvent);
953
    }
954
    return flag;
955 956
}

957 958
void ChannelImpl::detach_users(TensorInfo* dest) {
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
M
Megvii Engine Team 已提交
959
    for (auto* user : users) {
960 961
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
M
Megvii Engine Team 已提交
962 963 964 965 966
        for (auto* output : outputs) {
            // When a `ComputePath` is detach from it's input,
            // there is no need to reserve it,
            // so we detach all output of this path
            // to decrease it's `ref_cnt` to zero.
967 968 969 970 971
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
M
Megvii Engine Team 已提交
972 973
            for (auto* input : inputs) {
                input->ref_cnt--;
974
            }
975
        }
976
        // now user is dead
977
    }
978
    mgb_assert(dest->users.empty(), "ComputePath leaking");
979 980
}

981 982 983 984
bool ChannelImpl::check_available() {
    return !m_closed;
}

985 986 987 988 989
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
990
    MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
991
    bool require_host = prop == TensorProp::HostValue;
M
Megvii Engine Team 已提交
992
    auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
993 994
    bool wait_host = false;
    if (require_host && !host_available()) {
995 996
        // avoid dead lock
        lock.unlock();
997 998 999 1000 1001 1002 1003 1004 1005 1006
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), GetValue{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    GetValue{info},
            });
        }
1007
        lock.lock();
1008
        wait_host = true;
1009
    }
1010 1011
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
1012
        return require_host ? host_available() : static_cast<bool>(info->ptr);
1013
    });
1014
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
1015
    m_waitee = nullptr;
1016
    if (wait_host) {
1017 1018 1019
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
1020 1021 1022 1023 1024
    return info->ptr;
}

// Wake wait_tensor() if `info` is the tensor it is waiting on.
// Caller must hold m_mutex (hence "_unsafe").
// Fix: removed interleaved VCS blame-artifact lines that broke compilation.
void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
    }
}

// Snapshot of all currently valid tensor handles as TensorInfo pointers;
// used to seed Start/StopProfile commands.
// Fix: removed interleaved VCS blame-artifact lines that broke compilation.
std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
    for (auto* handle : m_valid_handle) {
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
    }
    return valid_tensors;
}

1039
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
1040
    bool in_worker = (get_worker_tid() == std::this_thread::get_id());
1041 1042 1043 1044 1045 1046
    auto reserve_size = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            bool evict_suc = auto_evict(1);
M
Megvii Engine Team 已提交
1047 1048
            if (!evict_suc)
                return false;
1049 1050 1051 1052
        }
        return true;
    };
    auto pre_level = set_log_level(LogLevel::NO_LOG);
1053 1054 1055
    if (in_worker) {
        reserve_size(x->size());
    }
1056
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
1057 1058
    MGB_CATCH(MemAllocError&, {
        bool suc = false;
1059 1060 1061 1062 1063 1064 1065 1066
        if (in_worker) {
            while (!suc) {
                if (!auto_evict(1)) {
                    break;
                }
                MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
                MGB_CATCH(MemAllocError&, { continue; });
                suc = true;
1067 1068 1069 1070
            }
        }
        if (!suc) {
            set_log_level(pre_level);
M
Megvii Engine Team 已提交
1071 1072 1073
            mgb_log_warn(
                    "reallocating all cuda memory to alleviate fragmentation, the "
                    "performance may be affected");
1074
            set_log_level(LogLevel::NO_LOG);
1075
            imperative_log_profile_begin("defrag");
1076
            BlobManager::inst()->defrag(x->comp_node());
1077
            imperative_log_profile_end("defrag");
1078
            BlobManager::inst()->alloc_direct(x, x->size());
1079 1080 1081 1082 1083
        }
    });
    set_log_level(pre_level);
}

1084
void ChannelImpl::process_one_task(Command& icmd) {
1085 1086
    using namespace ranges;
    using namespace ranges::views;
1087
    auto& state = get_worker_state();
1088
    auto& options = state.options;
M
Megvii Engine Team 已提交
1089
    // TODO: remove std::visit for support osx 10.12
1090
    auto cmd_visitor = [&](const auto& cmd) {
M
Megvii Engine Team 已提交
1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, Put>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
                                      : Tensor::make(cmd.value);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            produce_tensor(cmd.dest, std::move(value));
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, ApplyOp>) {
            for (auto& i : cmd.inputs) {
1108
                if (mgb_unlikely(i->invalid)) {
M
Megvii Engine Team 已提交
1109 1110 1111
                    MGB_LOCK_GUARD(m_mutex);
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
1112
                    }
M
Megvii Engine Team 已提交
1113 1114 1115
                    return;
                }
            }
1116 1117 1118 1119 1120 1121 1122 1123
            if (state.options.enable_dtr_auto_drop) {
                m_apply_stack.push({cmd, 0, nullptr, "cmd"});
                flush_apply_stack();
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
M
Megvii Engine Team 已提交
1124
                    output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
1125
                }
1126 1127
            } else {
                do_apply_op(cmd, "cmd");
M
Megvii Engine Team 已提交
1128 1129 1130 1131 1132 1133 1134
            }
            if (state.options.enable_drop && state.options.record_computing_path) {
                auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                    auto& input = std::get<0>(tuple2);
                    auto& output = std::get<1>(tuple2);
                    if (!input->ptr || !output->ptr) {
                        return false;
1135
                    }
M
Megvii Engine Team 已提交
1136 1137 1138 1139 1140 1141 1142
                    return input->ptr->blob()->storage() ==
                           output->ptr->blob()->storage();
                };
                // FIXME: do not use opname as identifier
                auto get_name = [](const OpDef& opdef) {
                    if (auto attr = opdef.try_cast_final<OprAttr>()) {
                        return attr->type.c_str();
1143
                    }
M
Megvii Engine Team 已提交
1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156
                    return opdef.dyn_typeinfo()->name;
                };

                auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
                    return info->desc.comp_node != comp_node;
                };

                bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                bool inplace =
                        any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                    TensorInfo::ComputePath::make(
1157
                            cmd.id, cmd.op, cmd.inputs, cmd.outputs, cmd.outputs_descs);
M
Megvii Engine Team 已提交
1158 1159
                    size_t detach_cnt = 0;
                    if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
1160
                        cmd.outputs.size() == 6) {
M
Megvii Engine Team 已提交
1161 1162
                        cmd.outputs[0]->detach_producer();  // detach running_mean
                        cmd.outputs[1]->detach_producer();  // detach running_var
1163
                        for (auto input : cmd.inputs) {
M
Megvii Engine Team 已提交
1164
                            input->ref_cnt -= 2;
1165 1166
                        }
                    }
M
Megvii Engine Team 已提交
1167 1168 1169 1170 1171 1172 1173
                    for (auto output : cmd.outputs) {
                        if (output->producer &&
                            !output->size_exceeds_thd(
                                    state.options.dtr_evictee_minimum_size)) {
                            output->detach_producer();
                            detach_cnt++;
                        }
1174
                    }
M
Megvii Engine Team 已提交
1175 1176
                    for (auto input : cmd.inputs) {
                        input->ref_cnt -= detach_cnt;
1177
                    }
1178
                }
1179
            }
M
Megvii Engine Team 已提交
1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195
        } else if constexpr (std::is_same_v<T, Del>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
            CompNode device = cmd.dest->desc.comp_node;
            uint64_t tensor_id = cmd.dest->id;
            free(cmd.dest);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
            sample_on_device(device, false);
        } else if constexpr (std::is_same_v<T, GetValue>) {
            if (cmd.dest->invalid)
                return;
            imperative_log_profile_begin("GetValue");
            if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                regenerate(cmd.dest);
            }
            cmd.dest->ptr->fetch_value();
1196
            MGB_LOCK_GUARD(m_mutex);
M
Megvii Engine Team 已提交
1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
            do_drop(cmd.dest, true);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
        } else if constexpr (std::is_same_v<T, SetOption>) {
            options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            MGB_RECORD_EVENT(StartProfileEvent);
            CompNode::sync_all();
            for (auto* info : cmd.capture_tensors) {
                MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                if (info->status == TensorInfo::Produced) {
1214
                    // TODO: handle drop
M
Megvii Engine Team 已提交
1215 1216 1217
                    MGB_RECORD_EVENT(
                            TensorProduceEvent, info->id, info->desc.layout,
                            info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
1218 1219
                }
            }
M
Megvii Engine Team 已提交
1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234
            CompNode::foreach ([&](CompNode device) {
                sample_on_device(device, true);
                MGB_RECORD_EVENT_IF(
                        (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                        Timer::record_device(device));
            });
            MGB_RECORD_EVENT(StartProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
            MGB_RECORD_EVENT(StopProfileEvent);
            for (auto* info : cmd.escape_tensors) {
                bool has_value = info->status == TensorInfo::Produced;
                if (has_value) {
                    MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
                }
                MGB_RECORD_EVENT(TensorEraseEvent, info->id);
1235
            }
M
Megvii Engine Team 已提交
1236 1237 1238 1239 1240 1241 1242 1243 1244
            CompNode::foreach (
                    [&](CompNode device) { sample_on_device(device, true); });
            MGB_RECORD_EVENT(StopProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, PushScope>) {
            MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
            MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
        } else {
            static_assert(!std::is_same_v<T, T>);
1245
        }
M
Megvii Engine Team 已提交
1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272
    };
    std::visit(
            [&](const auto& cmd) {
                using T = std::decay_t<decltype(cmd)>;
                if (!options.catch_worker_execption) {
                    cmd_visitor(cmd);
                    return;
                }
                try {
                    cmd_visitor(cmd);
                } catch (...) {
                    MGB_LOCK_GUARD(m_mutex);
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        for (auto oup : cmd.outputs) {
                            oup->invalid = true;
                        }
                    } else if constexpr (std::is_same_v<T, Put>) {
                        cmd.dest->invalid = true;
                    }
                    m_worker_exc = std::current_exception();
                    MGB_RECORD_EVENT(WorkerExceptionEvent);
                    if (m_waitee) {
                        notify_tensor_unsafe(m_waitee);
                    }
                }
            },
            icmd.data);
1273 1274 1275 1276
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
1277 1278
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
1279 1280
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
1281 1282 1283 1284 1285
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
1286 1287
    }
}
1288

1289
void ChannelImpl::start_profile() {
1290
    MGB_LOCK_GUARD(m_spin);
1291
    mgb_assert(check_available(), "Channel already closed");
1292 1293
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
1294 1295 1296 1297 1298 1299 1300 1301 1302 1303
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StartProfile{std::move(capture_tensors)},
            });
        }
1304
    }
1305 1306
}

1307
void ChannelImpl::stop_profile() {
1308
    MGB_LOCK_GUARD(m_spin);
1309
    mgb_assert(check_available(), "Channel already closed");
1310 1311
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
1312 1313 1314 1315 1316 1317 1318 1319 1320 1321
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StopProfile{std::move(escape_tensors)},
            });
        }
1322
    }
1323 1324 1325
}

// Enter a named profiling scope on the channel thread and mirror it on the
// worker via a PushScope command.
// Fix: removed interleaved VCS blame-artifact lines that broke compilation.
void ChannelImpl::push_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.stack_manager.enter(name);
    MGB_RECORD_EVENT(ScopeEvent, name);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PushScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PushScope{name},
        });
    }
}

// Leave a named profiling scope on the channel thread and mirror it on the
// worker via a PopScope command.
// Fix: removed interleaved VCS blame-artifact lines that broke compilation.
void ChannelImpl::pop_scope(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.stack_manager.exit(name);
    MGB_RECORD_EVENT(ScopeFinishEvent, name);
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PopScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PopScope{name},
        });
    }
}

1361
void ChannelImpl::assert_in_channel() {
M
Megvii Engine Team 已提交
1362 1363 1364
    mgb_assert(
            get_worker_tid() != std::this_thread::get_id(),
            "this method cannot be called in worker thread");
1365 1366 1367
}

void ChannelImpl::assert_in_worker() {
M
Megvii Engine Team 已提交
1368 1369 1370
    mgb_assert(
            get_worker_tid() == std::this_thread::get_id(),
            "this method can only be called in worker thread");
1371 1372
}

1373
void ChannelImpl::sample_on_device(CompNode device, bool force) {
1374 1375 1376
    if (!Profiler::is_profiling()) {
        return;
    }
1377 1378
    if (!force) {
        thread_local int last_sample_id = 0;
1379
        int sample_rate = Profiler::get_option("sample_rate", 0);
1380 1381 1382 1383
        if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
            return;
        }
    }
1384
    MGB_RECORD_EVENT(SampleDeviceEvent, device);
1385
    auto [total, free] = device.get_mem_status_bytes();
1386
    MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
1387 1388
}

1389 1390 1391
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
1392
        erase_candidate(i);
1393 1394 1395
    }
}

1396 1397
void ChannelImpl::DynamicSublinear::unpin(
        const SmallVector<TensorInfo*>& vec, WorkerState& state) {
1398 1399
    for (auto i : vec) {
        i->unpin();
1400 1401 1402 1403 1404
        if (i->pinned == 0 &&
            i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
            i->cand_index == UINT_MAX) {
            insert_candidate(i);
        }
1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450
    }
}

// After `ptr` is recomputed it no longer belongs to its evicted neighborhood:
// subtract its compute time from the set root's accumulated cost, then detach
// it into a fresh singleton set carrying only its own compute time.
void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& root = find_father(ptr->dsu_ptr);
    root->t -= ptr->compute_time;
    // detach: become an isolated DSU node again
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

// After `ptr` is evicted, merge its DSU set with those of every dropped
// neighbor (producer inputs and outputs), so their recompute costs are
// accumulated together.
void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    auto&& path = ptr->producer;
    for (auto* in : path->inputs) {
        if (in->evict_type == EvictType::DROP) {
            merge(in->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto* out : path->outputs) {
        if (out && out->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, out->dsu_ptr);
        }
    }
}

/**
 * \brief Estimate the recomputation cost incurred if \p ptr is evicted.
 *
 * Sums, over every DROP-evictable neighbor of \p ptr's producer, the
 * accumulated compute time of that neighbor's DSU component (at least
 * the neighbor's own compute time).
 */
double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    double cost = 0;
    // component cost of a neighbor, clamped below by its own compute time
    auto chain_time = [this](TensorInfo* nbr) {
        double t = find_father(nbr->dsu_ptr)->t;
        return t < nbr->compute_time ? nbr->compute_time : t;
    };
    for (auto* input : ptr->producer->inputs) {
        if (input->evict_type == EvictType::DROP) {
            cost += chain_time(input);
        }
    }
    for (auto* output : ptr->producer->outputs) {
        // outputs may contain null entries
        if (output && output->evict_type == EvictType::DROP) {
            cost += chain_time(output);
        }
    }
    return cost;
}

M
Megvii Engine Team 已提交
1451 1452
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
        bool enable_dtr_sqrt_sampling = false) {
1453 1454 1455
    if (candidates.empty())
        return nullptr;

1456 1457
    double min_msps = -1;
    TensorInfo* best = nullptr;
1458 1459
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
M
Megvii Engine Team 已提交
1460 1461
        while (sz * sz <= candidates.size())
            sz++;
1462
        sz--;
1463 1464 1465
    } else {
        sz = candidates.size();
    }
1466 1467 1468 1469 1470 1471 1472

    size_t ti = rand() % sz;
    for (size_t vi = 0; vi < sz; vi++) {
        if (!enable_dtr_sqrt_sampling) {
            ti = vi;
        }
        auto i = candidates[ti];
1473
        if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
1474
            double neighbor_cost = estimate_neighbor_cost(i);
M
Megvii Engine Team 已提交
1475 1476 1477 1478
            size_t begin_ptr =
                    reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(
                    begin_ptr, begin_ptr + i->ptr->blob()->size());
1479
            double free_mem = side_info.first + side_info.second;
M
Megvii Engine Team 已提交
1480 1481
            double msps = i->eval_func(
                    neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
1482 1483 1484 1485 1486
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
1487 1488 1489 1490 1491
        if (enable_dtr_sqrt_sampling) {
            ti += rand() % sz;
            if (ti > candidates.size())
                break;
        }
1492 1493 1494 1495
    }
    return best;
}

M
Megvii Engine Team 已提交
1496 1497
void ChannelImpl::DynamicSublinear::merge(
        std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
1498 1499 1500 1501 1502 1503 1504 1505 1506
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

M
Megvii Engine Team 已提交
1507 1508
std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
        std::shared_ptr<DsuNode>& x) {
1509 1510 1511 1512 1513 1514 1515 1516 1517
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

/**
 * \brief Add \p ptr to the eviction candidate set.
 *
 * Records the tensor's position in cand_index so erase_candidate() can
 * remove it in O(1) via swap-with-last.
 */
void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    // tensor to be inserted must be brand new
    mgb_assert(
            ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
            ptr->cand_index);
    ptr->cand_index = candidates.size();
    candidates.push_back(ptr);
    // latch the comp_node from the first candidate seen
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

/**
 * \brief Remove \p ptr from the eviction candidate set, if present.
 *
 * Uses swap-with-last + pop_back for O(1) removal and keeps every
 * surviving candidate's cand_index consistent with its position.
 */
void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    // close dtr will just clear candidates, so nothing to erase
    if (candidates.empty()) {
        ptr->cand_index = UINT_MAX;
        return;
    }
    // some tensors may be erased already, just skip them
    if (ptr->cand_index != UINT_MAX) {
        std::swap(candidates[ptr->cand_index], candidates.back());
        // the element swapped into ptr's old slot takes over its index
        candidates[ptr->cand_index]->cand_index = ptr->cand_index;
        candidates.pop_back();
        ptr->cand_index = UINT_MAX;
    }
}

// Record the current estimated timestamp as \p ptr's last access time;
// presumably consumed by the eviction scoring (estimate_timestamp is also
// fed to eval_func in find_best_tensor) — confirm against TensorInfo usage.
void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}