interpreter_impl.cpp 53.6 KB
Newer Older
M
Megvii Engine Team 已提交
1
/**
2
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
M
Megvii Engine Team 已提交
3 4
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
M
Megvii Engine Team 已提交
6 7 8 9 10 11
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12
#include "./interpreter_impl.h"
13

14 15
#include "range/v3/all.hpp"

16
#include "megbrain/common.h"
17 18
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
19 20
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
21
#include "megbrain/imperative/ops/utility.h"
22
#include "megbrain/imperative/utils/stats.h"
23 24
#include "megbrain/imperative/utils/to_string.h"

25
#include "../blob_manager_impl.h"
26 27 28
#include "../event_pool.h"
#include "../op_trait.h"

29 30 31 32 33
// This translation unit implements the interpreter channel; pull in the
// namespaces it lives in to keep the definitions below terse.
using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

34
namespace {
M
Megvii Engine Team 已提交
35 36 37 38 39 40 41 42
auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
    SmallVector<uint64_t> tid;
    for (auto* ptinfo : tinfo) {
        tid.push_back(ptinfo->id);
    }
    return tid;
};
}  // namespace
43

44
namespace mgb {
M
Megvii Engine Team 已提交
45
using namespace profiler;
46 47
}

48 49 50 51 52
// Mark the profiling hooks below as visible dynamic symbols so external
// tooling can resolve them by name: dllexport on Windows, default ELF
// visibility elsewhere.
#if defined(_WIN32) || defined(_WIN64)
#define SYMBOL_EXPORT __declspec(dllexport)
#else
#define SYMBOL_EXPORT __attribute__((visibility("default")))
#endif
53 54 55 56 57 58 59

namespace mgb {

/**
 * USAGE
 *
 *   header:
 *     namespace mgb { void imperative_log_profile(const char* message); }
 *
 *   code:
 *     mgb::imperative_log_profile("MY MESSAGE");
 *
 **/

// Open a custom profiler interval named `message`.
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message) {
    MGB_RECORD_EVENT(CustomEvent, std::string{message});
}

// Close the custom profiler interval named `message`.
SYMBOL_EXPORT
void imperative_log_profile_end(const char* message) {
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
}

// Record an instantaneous marker: a begin immediately followed by an end.
SYMBOL_EXPORT
void imperative_log_profile(const char* message) {
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}

// Device-scoped variant: additionally records a device timing event on
// `device` so the interval can be correlated with the device timeline.
SYMBOL_EXPORT
void imperative_log_profile_begin(const char* message, const char* device) {
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
}

SYMBOL_EXPORT
void imperative_log_profile_end(const char* message, const char* device) {
    auto comp_node = CompNode::load(device);
    MGB_RECORD_EVENT(
            RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
    MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
}

}  // namespace mgb
99

100 101 102 103 104 105 106 107 108 109 110 111 112 113
// Returns the id of the worker thread; used by the thread-affinity asserts.
std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

// Accessor for channel-side state; asserts the caller is on the channel
// (user) thread before handing out the reference.
ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

// Accessor for worker-side state; asserts the caller is the worker thread.
ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

114 115 116 117 118 119 120 121 122 123
// Runs once when the async worker thread starts: names the thread, records
// its id for the affinity asserts, and installs an OpDef allocator that
// allocates blobs through the channel's evicting allocator.
void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
    sys::set_thread_name("worker");
    m_owner->m_worker_state.tid = std::this_thread::get_id();
    OpDef::set_allocator([&](CompNode device, size_t size) {
        auto blob = Blob::make(device, size);
        m_owner->alloc_tensor_with_evict(blob.get());
        return blob->storage();
    });
}

124
// Do not use m_xxx_state directly: accessing the members without going
// through get_channel_state()/get_worker_state() would bypass the
// thread-affinity asserts, so poison the names for the rest of this file.
#define m_channel_state
#define m_worker_state

128 129 130 131 132 133 134 135 136
// Factory: each channel owns an independent worker queue and tensor pool.
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

// Process-wide singleton interpreter (function-local static: thread-safe
// initialization guaranteed by the language).
Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

137
// Puts a host tensor into the channel and returns an opaque handle. The
// actual device upload is performed asynchronously by the worker thread.
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    auto info = put_impl(value, no_cache);
    return reinterpret_cast<Handle>(info);
}

TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    // Normalize empty tensors to a contiguous layout before recording them.
    // NOTE(review): the const_cast mutates the caller's tensor in place;
    // existing callers appear to rely on this, so it is kept as-is.
    if (value.empty()) {
        auto layout = value.layout();
        layout.init_contiguous_stride();
        const_cast<HostTensorND&>(value).reset(value.storage(), layout);
    }
    auto info = alloc();
    // Small tensors additionally keep a host-side copy so value/shape
    // inference can read them without a device round trip.
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {value.layout(), value.comp_node()});
    if (value.layout().total_nr_elems() <= size_threshold) {
        info->h_value = value;
        info->desc.value = value.proxy_to_default_cpu();
    }
    // Enqueue the Put command; only profiling runs pay for a stack dump.
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Put{info, value, no_cache},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Put{info, value, no_cache},
        });
    }
    // Fully synchronous mode: wait for the worker and surface device errors.
    if (m_async_level == 0) {
        sync_impl();
        info->desc.comp_node.sync();
        auto err = info->desc.comp_node.check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info;
}

178
// Puts an already-materialized device tensor (with an optional host mirror)
// into the channel; the tensor is produced immediately, no worker round trip.
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    return reinterpret_cast<Handle>(put_impl(data, hvalue));
}
M
Megvii Engine Team 已提交
183 184
TensorInfo* ChannelImpl::put_impl(
        const DeviceTensorND& data, const HostTensorND& hvalue) {
    auto& state = get_channel_state();
    auto _ = StackManager::Guard{"Put", &state.stack_manager};
    auto info = alloc();
    MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
    // Keep a host-side value only for small tensors (cheap to copy and
    // useful for host-side value inference).
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    init(info, {data.layout(), data.comp_node()});
    if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
        info->desc.value = hvalue.proxy_to_default_cpu();
    }
    info->ptr = Tensor::make(data, hvalue);
    MGB_RECORD_EVENT(
            TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
            data.raw_ptr());
    info->status = TensorInfo::Produced;
    MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
    return info;
}

203
// Releases a handle. Deleting on an already-closed channel is a no-op
// because close() has already deleted every live handle.
void ChannelImpl::del(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    if (!check_available()) {
        return;
    }
    del_impl(handle);
}

void ChannelImpl::del_impl(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    // The TensorInfo itself is freed by the worker when the Del command runs.
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), Del{info},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                Del{info},
        });
    }
}

227
// Requests eviction (drop) of a tensor's device storage; it will be
// regenerated on demand from its producer. Ignored unless the enable_drop
// option is set.
void ChannelImpl::drop(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(
                m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), Drop{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    Drop{info},
            });
        }
    }
}

249
void ChannelImpl::dispatch_default_cpu(
M
Megvii Engine Team 已提交
250
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
251 252
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
253
    auto& state = get_channel_state();
254 255

    auto name = op->trait()->make_name(*op);
256
    auto _ = StackManager::Guard(name, &state.stack_manager);
257

M
Megvii Engine Team 已提交
258 259
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
260
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
261

262 263
    SmallVector<DeviceTensorND> input_tensornds;
    CompNode output_cn;
264 265
    {
        MGB_LOCK_GUARD(m_mutex);
266
        for (auto&& info : input_infos) {
267
            auto input_cn = info->desc.comp_node;
268
            if (!output_cn.valid()) {
269 270 271 272 273 274
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
M
Megvii Engine Team 已提交
275 276
                input_tensornds.emplace_back(
                        info->ptr->get_value().proxy_to_default_cpu());
277
            } else {
278
                // We assign h_value before drop ptr
279 280
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
281 282 283 284 285 286 287 288 289
            }
        }
    }

    SmallVector<DeviceTensorND> output_tensornds;
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which need alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
M
Megvii Engine Team 已提交
290 291
        output_tensornds.emplace_back(
                HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
292 293
    }

294
    uint64_t op_id = Profiler::next_id();
295

296 297 298 299 300 301 302 303 304
    if (op->trait()->apply_on_device_tensornd) {
        OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
    } else {
        // proxy to apply_on_physical_tensor
        SmallVector<TensorPtr> input_tensors;
        for (auto&& input_tensornd : input_tensornds) {
            input_tensors.push_back(Tensor::make(
                    input_tensornd, HostTensorND::make_proxy(input_tensornd)));
        }
305 306
        auto output_tensors = OpDef::apply_on_physical_tensor(
                *op, input_tensors, output_descs, validated);
307 308 309 310
        for (size_t i = 0; i < output_tensors.size(); ++i) {
            output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
        }
    }
311 312 313

    SmallVector<TensorInfo*> output_infos;
    for (auto&& tensornd : output_tensornds) {
M
Megvii Engine Team 已提交
314 315
        HostTensorND host_tensornd =
                HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
316
        // use `put` for consistency
317
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
318
        mgb_assert(info->desc.layout.ndim != 0);
319
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
320
        outputs->push_back(reinterpret_cast<Handle>(info));
321
    }
M
Megvii Engine Team 已提交
322
    auto op_info_getter = [op] {
323 324
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
M
Megvii Engine Team 已提交
325
        for (auto&& [key, value] : props) {
326 327 328 329
            op_info[key] = value;
        }
        return op_info;
    };
M
Megvii Engine Team 已提交
330 331 332
    MGB_RECORD_EVENT(
            OpDispatchEvent, op_id, name, op_info_getter, tinfo_to_tid(input_infos),
            tinfo_to_tid(output_infos), state.stack_manager.dump());
333
}
334

335
void ChannelImpl::dispatch_kernel(
M
Megvii Engine Team 已提交
336
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
337 338
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
339
    auto& state = get_channel_state();
340 341
    auto& options = state.options;

342 343 344
    auto name = op->trait()->make_name(*op);
    auto _ = StackManager::Guard{name, &state.stack_manager};

M
Megvii Engine Team 已提交
345 346
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
347
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
348

349 350 351 352
    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());

    outputs->reserve(output_descs.size());
353 354
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
355
        auto info = alloc();
356
        init(info, std::move(desc));
357 358 359
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
M
Megvii Engine Team 已提交
360
                                    .proxy_to_comp_node(desc.comp_node);
361
        }
362
        output_infos.push_back(info);
M
Megvii Engine Team 已提交
363
        outputs->push_back(reinterpret_cast<Handle>(info));
364
    }
365 366 367
    ApplyOp cmd{
            Profiler::next_id(), std::move(op), std::move(input_infos),
            std::move(output_infos), validated};
368
    if (Profiler::is_profiling()) {
369 370 371 372 373 374 375 376
        auto op_info_getter = [op = cmd.op] {
            std::unordered_map<std::string, std::string> op_info;
            auto props = OpDef::props(*op);
            for (auto&& [key, value] : props) {
                op_info[key] = value;
            }
            return op_info;
        };
377
        MGB_RECORD_EVENT(
378 379
                OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
                tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
380
        m_worker.add_task(
381
                {Profiler::next_id(), std::move(cmd),
382 383 384 385
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
386
                std::move(cmd),
387 388
        });
    }
389
    if (!validated && options.async_level == 1) {
390
        sync_impl();
391
    } else if (options.async_level == 0) {
392
        sync_impl();
393
        // check device error
394
        for (auto&& oup : *outputs) {
395 396
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
397 398
            auto err = info->ptr->comp_node().check_async_error();
            mgb_assert(!err, "%s", err->what());
399
        }
400
    }
401 402 403
}

// Public apply entry point. Fast path: a full GetVarShape on a tensor whose
// shape is already known is answered immediately from the descriptor,
// without touching the worker.
SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    if (!inputs.empty()) {  // guard: original indexed inputs[0] unchecked
        auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
        if (op->same_type<GetVarShape>() && input->desc.layout.ndim) {
            size_t ndim = input->desc.layout.ndim;
            auto& gvs = op->cast_final_safe<GetVarShape>();
            if (gvs.axis == MEGDNN_MAX_NDIM) {
                HostTensorND shape_tensor{
                        input->desc.comp_node, {ndim}, dtype::Int32()};
                DeviceTensorND shape_tensor_device =
                        shape_tensor.proxy_to_default_cpu();
                cg::copy_shape_to_tensor_value(
                        shape_tensor_device, input->desc.layout);
                return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
            }
        }
    }
    return apply_op_impl(std::move(op), inputs);
}

// Validates the inputs, snapshots their descriptors under the info spin
// lock, then routes the op either to the host fast path or to the worker
// kernel path depending on the decided dispatch mode.
SmallVector<Handle> ChannelImpl::apply_op_impl(
        std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(
                m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
                i);
    }
    SmallVector<TensorInfo*> input_infos;
    SmallVector<LogicalTensorDesc> input_descs;
    {
        MGB_LOCK_GUARD(m_info_spin);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(
                    !info->invalid,
                    "an input tensor is unusable due to previous error");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    // Host compute is only considered when the option is enabled; otherwise
    // everything goes through the kernel path.
    DispatchMode dispatch_mode = state.options.enable_host_compute
                                       ? OpDef::decide_dispatch_mode(*op, input_descs)
                                       : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

460
// Blocks until the tensor's host value is available and returns it.
HostTensorND ChannelImpl::get_value(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "tensor is unusable due to previous error");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

472
// Returns the tensor's shape; answered from the descriptor when already
// known, otherwise blocks on the worker until the shape is produced.
TensorShape ChannelImpl::get_shape(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

487
// Returns the tensor's dtype; always available from the descriptor,
// so this never blocks.
DType ChannelImpl::get_dtype(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

500
// Returns the tensor's comp node; always available from the descriptor,
// so this never blocks.
CompNode ChannelImpl::get_device(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

513
// Blocks until the tensor's device value is produced and returns it.
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    mgb_assert(
            m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
            handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

void ChannelImpl::sync() {
524
    MGB_LOCK_GUARD(m_spin);
525
    mgb_assert(check_available(), "Channel already closed");
526 527 528 529
    sync_impl();
}

void ChannelImpl::sync_impl() {
530 531 532 533 534 535
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
536
    MGB_LOCK_GUARD(m_spin);
537 538 539 540
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
M
Megvii Engine Team 已提交
541
    for (auto* handle : valid_handles) {
542
        del_impl(handle);
543 544 545
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
546
    sync_impl();
547
    m_closed = true;
548 549
}

550
size_t ChannelImpl::get_option(std::string name) {
551
    MGB_LOCK_GUARD(m_spin);
552
    mgb_assert(check_available(), "Channel already closed");
553 554
    auto& state = get_channel_state();
    return state.options.get_option(name);
555 556
}

557
void ChannelImpl::set_option(std::string name, size_t value) {
558
    MGB_LOCK_GUARD(m_spin);
559
    mgb_assert(check_available(), "Channel already closed");
560 561
    auto& state = get_channel_state();
    state.options.set_option(name, value);
562 563 564 565 566 567 568 569 570 571
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), SetOption{name, value},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                SetOption{name, value},
        });
    }
572 573
}

574 575 576 577 578 579
// Clears the DTR eviction candidate set.
void ChannelImpl::clear_candidates() {
    MGB_LOCK_GUARD(m_spin);
    mgb_assert(check_available(), "Channel already closed");
    m_dtr.candidates.clear();
}

580
// Allocates a TensorInfo from the pool and assigns it a fresh id (plus a
// human-readable name derived from the stack manager when profiling).
TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this] {
        // Pool access is shared with the worker's free path; guard it.
        MGB_LOCK_GUARD(m_pool_spin);
        auto* ptr = m_pool.alloc_raw();
        new (ptr) TensorInfo();
        return (TensorInfo*)ptr;
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        size_t tensor_id = state.stack_manager.current()->next_id("tensor");
        info->name =
                state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
    }
    return info;
}

597
// Registers a freshly allocated TensorInfo as a valid handle and seeds its
// logical descriptor.
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
    m_valid_handle.insert(reinterpret_cast<Handle>(info));
    MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    // Intentionally copies rather than moves: some callers (see
    // dispatch_kernel) still read `desc` after passing it in.
    info->desc = desc;
}

M
Megvii Engine Team 已提交
604
// Worker-side drop: releases device storage while keeping the producer
// chain so the tensor can be regenerated later. `user` marks an explicit
// user request (warn when it cannot be honored because the producer died).
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn(
                    "the input that produced tensor %p has been deleted, this drop "
                    "operation will be ignored",
                    ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;  // already evicted one way or another
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

622
// Frees a TensorInfo. Under DTR auto-drop, tensors still referenced by
// others are evicted instead of destroyed so they can be regenerated.
void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

// Frees `ptr` and transitively frees any producer inputs whose reference
// count drops to zero and that are marked deletable.
void ChannelImpl::recursive_free(TensorInfo* ptr) {
    // Capture the id up front: real_free() returns `ptr` to the pool, so
    // dereferencing it for the finish event below would be a use-after-free.
    auto id = ptr->id;
    MGB_RECORD_EVENT(TensorCommandEvent, id, TensorCommandKind::RecFree);
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    MGB_RECORD_EVENT(TensorCommandFinishEvent, id, TensorCommandKind::RecFree);
}

// Actually destroys a TensorInfo: removes it from DTR tracking, detaches it
// from the dependency graph, emits bookkeeping events and returns it to the
// pool.
void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    MGB_LOCK_GUARD(m_pool_spin);
    m_pool.free(ptr);
}

675
// The worker queue keeps a back-pointer to its owning channel.
ChannelImpl::ChannelImpl() : m_worker(this) {}
676

677 678 679
// Destructor closes the channel (close() is a no-op if already closed).
ChannelImpl::~ChannelImpl() {
    close();
}
680

681
// Worker-side: installs the physical tensor `ptr` into `dest`, updates the
// logical descriptor, DTR bookkeeping, and wakes any thread waiting on this
// tensor.
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
            ptr->dev_tensor(false).raw_ptr());
    // update tensor desc for static infer
    if (dest->desc.layout.ndim) {
        mgb_assert(
                dest->desc.layout.eq_shape(ptr->layout()),
                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
                ptr->layout().to_string().c_str());
    }
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    // Unpinned tensors above the size threshold become DTR evictees.
    if (dest->pinned == 0 &&
        dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
}

708
// Drops the device storage of `dest` and removes it from the DTR candidate
// set if it was large enough to be tracked.
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
}

718
// Recomputes a dropped tensor by replaying its producer op. The apply is
// pushed onto m_apply_stack; when no apply is already in flight, the stack
// is flushed immediately.
void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == EvictType::DROP) {
        auto&& path = dest->producer;
        m_apply_stack.push(
                {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
                 "dtr"});
        if (!m_applying)
            flush_apply_stack();
    }
}

729
void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
730 731
    using namespace ranges;
    using namespace ranges::views;
732
    auto& state = get_worker_state();
M
Megvii Engine Team 已提交
733 734
    bool profiling_device =
            Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
735
    uint64_t apply_id = cmd.id;
736
    SmallVector<TensorPtr> inputs;
737
    inputs.reserve(cmd.inputs.size());
738 739 740
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
741
        // refcnt ++, owners: [i->ptr, tensor_inputs]
742
        // tensor_inputs.push_back(i->ptr);
743
        inputs.push_back(i->ptr);
744
    }
M
Megvii Engine Team 已提交
745 746
    if (state.options.enable_dtr_auto_drop &&
        state.options.dtr_eviction_threshold > 0) {
747 748
        auto_evict(0);
    }
M
Megvii Engine Team 已提交
749
    auto apply_on_physical_tensor =
750
            [&](auto&& self, const OpDef& def, SmallVector<TensorPtr>&& inputs,
751 752
                SmallVector<LogicalTensorDesc>& output_descs,
                const bool& validated) -> SmallVector<TensorPtr> {
753
        if (def.trait()->make_forward_graph) {
754 755 756 757 758 759 760 761 762 763
            auto apply_functor = [&](std::shared_ptr<OpDef> op,
                                     SmallVector<TensorPtr> inputs,
                                     size_t nr_outputs) -> SmallVector<TensorPtr> {
                auto opname = op->trait()->make_name(*op);
                imperative_log_profile_begin(opname.c_str());
                auto outputs = self(self, *op, std::move(inputs), output_descs, false);
                imperative_log_profile_end(opname.c_str());
                return outputs;
            };
            auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
764 765
            // apply recursivily
            SmallVector<LogicalTensorDesc> input_descs;
M
Megvii Engine Team 已提交
766
            for (auto&& input : inputs) {
767
                input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
768
            }
769
            auto forward_graph = OpDef::make_forward_graph(def, input_descs);
770 771
            auto outputs = forward_graph.apply<TensorPtr>(
                    inputs, apply_functor, const_functor);
772 773
            return outputs;
        }
774 775 776 777 778 779 780 781 782 783 784 785
        // Check Input Layout
        // Get the input layout constraints, and if the constraint is not satisfied
        // inplace update the layout and blob to make the tensor contiguous
        auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
        for (size_t idx = 0; idx < inputs.size(); ++idx) {
            auto&& layout_checker = constraints[idx];
            if (layout_checker) {
                inputs[idx]->to_contiguous_inplace(layout_checker);
            }
        }
        return OpDef::apply_on_physical_tensor(
                def, std::move(inputs), output_descs, validated);
786
    };
787
    MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
788 789 790 791
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
792 793 794
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
795
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
796 797 798
            }
        }
    }
M
Megvii Engine Team 已提交
799
    for (auto* input : cmd.inputs) {
800
        auto input_id = input->id;
801 802 803
        MGB_RECORD_EVENT(OpInputEvent, input_id);
        MGB_RECORD_EVENT(TensorUsageEvent, input_id);
        MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
804 805
    }
    // Before wait
M
Megvii Engine Team 已提交
806
    // TODO: split operator wait and execute so that OpWait could be corrected recorded.
807
    // Before execute
M
Megvii Engine Team 已提交
808
    for (auto&& [device, kernel_id] : kernels) {
809
        MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
M
Megvii Engine Team 已提交
810
        MGB_RECORD_EVENT_IF(
811
                profiling_device, RecordDeviceEvent, Timer::record_device(device));
812 813
    }
    // Apply op
814 815 816 817
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto i : cmd.outputs) {
        output_descs.push_back(i->desc);
    }
818
    // Here std::move is REQUIRED for removing duplicated references.
819
    auto outputs = apply_on_physical_tensor(
820 821
            apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs,
            cmd.validated);
822
    // After execute
M
Megvii Engine Team 已提交
823 824
    for (auto&& [device, kernel_id] : kernels) {
        MGB_RECORD_EVENT_IF(
825
                profiling_device, RecordDeviceEvent, Timer::record_device(device));
826
        MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
827 828
    }
    // End profiling operator
829 830
    mgb_assert(outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
831
        auto output = cmd.outputs[i];
832
        if (mgb_unlikely(output == nullptr)) {
833 834
            MGB_RECORD_EVENT(OpOutputEvent, 0);
            MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
835
        } else if (mgb_unlikely(output->ptr != nullptr)) {
836 837
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
838
        } else {
839
            MGB_RECORD_EVENT(OpOutputEvent, output->id);
840
            produce_tensor(output, outputs[i]);
841
            MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
842 843 844
            if (Profiler::is_profiling()) {
                sample_on_device(output->desc.comp_node, false);
            }
845 846 847 848 849 850 851 852
        }
    }

    if (state.options.enable_dtr_auto_drop) {
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
853
        for (auto i : outputs) {
854
            estimate_compute_time += i->blob()->size();
855 856 857 858 859 860 861
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
862
        m_dtr.unpin(cmd.inputs, state);
863
    }
864
    MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
865
    // End profiling operator
866
}
867

868 869
void ChannelImpl::flush_apply_stack() {
    m_applying = true;
870
    auto& state = get_worker_state();
871
    while (!m_apply_stack.empty()) {
M
Megvii Engine Team 已提交
872 873
        auto& [cmd, idx, recomp, reason] =
                m_apply_stack.top();  // cmd.inputs[0~idx-1] is in memory
874 875 876 877 878
        if (idx == 0) {
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            if (recomp) {
M
Megvii Engine Team 已提交
879 880
                MGB_RECORD_EVENT(
                        TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
881 882 883
            }
        }
        bool regen = false;
M
Megvii Engine Team 已提交
884
        for (size_t i = idx; i < cmd.inputs.size(); i++) {
885 886 887 888 889 890
            auto&& p = cmd.inputs[i];
            if (state.options.enable_dtr_auto_drop) {
                m_dtr.update_used_time(p);
            }
            if (!p->ptr && p->evict_type != EvictType::NONE) {
                idx = i + 1;
M
Megvii Engine Team 已提交
891
                regenerate(p);  // add ApplyOp to the stack
892 893 894 895
                regen = true;
                break;
            }
        }
M
Megvii Engine Team 已提交
896 897
        if (regen)
            continue;
898
        // the required input tensors are already in memory
M
Megvii Engine Team 已提交
899 900
        auto [cmd_backup, recomp_backup, reason_backup] =
                std::make_tuple(cmd, recomp, reason);
901
        m_apply_stack.pop();
902
        do_apply_op(cmd_backup, reason_backup);
903
        if (recomp_backup) {
M
Megvii Engine Team 已提交
904 905 906
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, recomp_backup->id,
                    TensorCommandKind::ReGen);
907 908
            for (auto o : cmd_backup.outputs) {
                if (o) {
909 910 911 912
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
913
    }
914
    m_applying = false;
915 916
}

917
bool ChannelImpl::auto_evict(size_t force_num) {
918
    auto& state = get_worker_state();
919
    if (!m_dtr.comp_node.valid()) {
920
        return false;
921 922
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
923
    size_t flag = false;
M
Megvii Engine Team 已提交
924 925 926
    while ((state.options.dtr_eviction_threshold > 0 &&
            current_memory > state.options.dtr_eviction_threshold) ||
           force_num > 0) {
927
        MGB_RECORD_EVENT(AutoEvictEvent);
928
        sample_on_device(m_dtr.comp_node, false);
929
        auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
930
        if (!best) {
931
            MGB_RECORD_EVENT(AutoEvictFinishEvent);
932 933 934 935
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
936
            if (force_num > 0) {
M
Megvii Engine Team 已提交
937
                force_num--;
938 939
            }
            flag = true;
940 941 942 943
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
944
        }
945
        sample_on_device(m_dtr.comp_node, false);
946
        MGB_RECORD_EVENT(AutoEvictFinishEvent);
947
    }
948
    return flag;
949 950
}

951 952
// Detach every ComputePath that consumes `dest`. Each user path's outputs are
// first regenerated (so no one later needs the path for recomputation) and
// then detached from their producer; input ref_cnts are decremented
// accordingly so the paths are fully released.
void ChannelImpl::detach_users(TensorInfo* dest) {
    // iterate over a copy: detach_producer() mutates dest->users
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user : users) {
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output : outputs) {
            // When a `ComputePath` is detached from its input,
            // there is no need to preserve it, so we detach all
            // outputs of this path to drive its `ref_cnt` to zero.
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
            for (auto* input : inputs) {
                input->ref_cnt--;
            }
        }
        // now user is dead
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}

975 976 977 978
// Returns true while the channel has not been close()d.
bool ChannelImpl::check_available() {
    return !m_closed;
}

979 980 981 982 983
// Block the calling (channel) thread until `info` satisfies `prop`:
// for HostValue the host copy must be fetched (a GetValue task is enqueued to
// the worker if needed); otherwise any device value suffices. Only one waiter
// is allowed at a time (m_waitee). Worker exceptions are re-raised through
// check_worker_exc_unsafe() inside the wait predicate.
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool require_host = prop == TensorProp::HostValue;
    auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
    bool wait_host = false;
    if (require_host && !host_available()) {
        // avoid dead lock: add_task may synchronize with the worker,
        // which also takes m_mutex
        lock.unlock();
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), GetValue{info},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    GetValue{info},
            });
        }
        lock.lock();
        wait_host = true;
    }
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return require_host ? host_available() : static_cast<bool>(info->ptr);
    });
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
    m_waitee = nullptr;
    if (wait_host) {
        // surface async device errors that completed during the host fetch
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    return info->ptr;
}

// Wake the thread blocked in wait_tensor() if it is waiting on `info`.
// "unsafe": caller must already hold m_mutex.
void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
    }
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
M
Megvii Engine Team 已提交
1026
    for (auto* handle : m_valid_handle) {
1027 1028
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
1029
    }
1030
    return valid_tensors;
1031 1032
}

1033
// Allocate storage for blob `x`, evicting DTR candidates on failure.
// Strategy: pre-evict until a large-enough free block exists, try a direct
// allocation, and on MemAllocError alternate evict/retry; as a last resort
// defragment the whole device and allocate once more (which may still throw).
// Logging is silenced during the retry loop to avoid allocator noise.
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
    auto reserve_size = [&](size_t size) {
        if (!m_dtr.comp_node.valid()) {
            return false;
        }
        while (size > m_dtr.comp_node.get_max_block_size_available()) {
            bool evict_suc = auto_evict(1);
            if (!evict_suc)
                return false;
        }
        return true;
    };
    auto pre_level = set_log_level(LogLevel::NO_LOG);
    reserve_size(x->size());
    MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
    MGB_CATCH(MemAllocError&, {
        bool suc = false;
        while (!suc) {
            if (!auto_evict(1)) {
                break;  // nothing left to evict
            }
            MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
            MGB_CATCH(MemAllocError&, { continue; });
            suc = true;
        }
        if (!suc) {
            // restore logging just for this warning, then silence again
            set_log_level(pre_level);
            mgb_log_warn(
                    "reallocating all cuda memory to alleviate fragmentation, the "
                    "performance may be affected");
            set_log_level(LogLevel::NO_LOG);
            imperative_log_profile_begin("defrag");
            BlobManager::inst()->defrag(x->comp_node());
            imperative_log_profile_end("defrag");
            BlobManager::inst()->alloc_direct(x, x->size());
        }
    });
    set_log_level(pre_level);
}

1073
// Worker-thread dispatcher: executes one queued Command by visiting its
// variant payload. Each alternative (Put/ApplyOp/Del/GetValue/Drop/SetOption/
// StartProfile/StopProfile/PushScope/PopScope) is handled in cmd_visitor;
// when catch_worker_execption is enabled, exceptions are captured into
// m_worker_exc, affected tensors are marked invalid, and the waiter is woken.
void ChannelImpl::process_one_task(Command& icmd) {
    using namespace ranges;
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    // TODO: remove std::visit for support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, Put>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
                                      : Tensor::make(cmd.value);
            MGB_RECORD_EVENT_IF(
                    (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                    Timer::record_device(cmd.value.comp_node()));
            produce_tensor(cmd.dest, std::move(value));
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, ApplyOp>) {
            // poison propagation: if any input is invalid, invalidate outputs
            for (auto& i : cmd.inputs) {
                if (mgb_unlikely(i->invalid)) {
                    MGB_LOCK_GUARD(m_mutex);
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
                    }
                    return;
                }
            }
            if (state.options.enable_dtr_auto_drop) {
                m_apply_stack.push({cmd, 0, nullptr, "cmd"});
                flush_apply_stack();
                for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                    auto output = cmd.outputs[i];
                    if (output == nullptr) {
                        continue;
                    }
                    output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                }
            } else {
                do_apply_op(cmd, "cmd");
            }
            if (state.options.enable_drop && state.options.record_computing_path) {
                auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                    auto& input = std::get<0>(tuple2);
                    auto& output = std::get<1>(tuple2);
                    if (!input->ptr || !output->ptr) {
                        return false;
                    }
                    return input->ptr->blob()->storage() ==
                           output->ptr->blob()->storage();
                };
                // FIXME: do not use opname as identifier
                auto get_name = [](const OpDef& opdef) {
                    if (auto attr = opdef.try_cast_final<OprAttr>()) {
                        return attr->type.c_str();
                    }
                    return opdef.dyn_typeinfo()->name;
                };

                auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
                    return info->desc.comp_node != comp_node;
                };

                bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                bool inplace =
                        any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);

                // Record the path only when recomputation would be valid:
                // no in-place aliasing, single device, and not a blacklisted op.
                if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                    TensorInfo::ComputePath::make(
                            cmd.id, cmd.op, cmd.inputs, cmd.outputs);
                    size_t detach_cnt = 0;
                    if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
                        cmd.outputs.size() == 5) {
                        cmd.outputs[0]->detach_producer();  // detach running_mean
                        cmd.outputs[1]->detach_producer();  // detach running_var
                        for (auto input : cmd.inputs) {
                            input->ref_cnt -= 2;
                        }
                    }
                    for (auto output : cmd.outputs) {
                        // tiny outputs are cheaper to keep than to recompute
                        if (output->producer &&
                            !output->size_exceeds_thd(
                                    state.options.dtr_evictee_minimum_size)) {
                            output->detach_producer();
                            detach_cnt++;
                        }
                    }
                    for (auto input : cmd.inputs) {
                        input->ref_cnt -= detach_cnt;
                    }
                }
            }
        } else if constexpr (std::is_same_v<T, Del>) {
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
            CompNode device = cmd.dest->desc.comp_node;
            uint64_t tensor_id = cmd.dest->id;
            free(cmd.dest);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
            sample_on_device(device, false);
        } else if constexpr (std::is_same_v<T, GetValue>) {
            if (cmd.dest->invalid)
                return;
            imperative_log_profile_begin("GetValue");
            if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                regenerate(cmd.dest);
            }
            cmd.dest->ptr->fetch_value();
            MGB_LOCK_GUARD(m_mutex);
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
            MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
            do_drop(cmd.dest, true);
            MGB_RECORD_EVENT(
                    TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
        } else if constexpr (std::is_same_v<T, SetOption>) {
            options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            MGB_RECORD_EVENT(StartProfileEvent);
            CompNode::sync_all();
            for (auto* info : cmd.capture_tensors) {
                MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                if (info->status == TensorInfo::Produced) {
                    // TODO: handle drop
                    MGB_RECORD_EVENT(
                            TensorProduceEvent, info->id, info->desc.layout,
                            info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                }
            }
            CompNode::foreach ([&](CompNode device) {
                sample_on_device(device, true);
                MGB_RECORD_EVENT_IF(
                        (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
                        Timer::record_device(device));
            });
            MGB_RECORD_EVENT(StartProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
            MGB_RECORD_EVENT(StopProfileEvent);
            for (auto* info : cmd.escape_tensors) {
                bool has_value = info->status == TensorInfo::Produced;
                if (has_value) {
                    MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
                }
                MGB_RECORD_EVENT(TensorEraseEvent, info->id);
            }
            CompNode::foreach (
                    [&](CompNode device) { sample_on_device(device, true); });
            MGB_RECORD_EVENT(StopProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, PushScope>) {
            MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
            MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
        } else {
            // exhaustiveness check: fail to compile on an unhandled alternative
            static_assert(!std::is_same_v<T, T>);
        }
    };
    std::visit(
            [&](const auto& cmd) {
                using T = std::decay_t<decltype(cmd)>;
                if (!options.catch_worker_execption) {
                    cmd_visitor(cmd);
                    return;
                }
                try {
                    cmd_visitor(cmd);
                } catch (...) {
                    MGB_LOCK_GUARD(m_mutex);
                    if constexpr (std::is_same_v<T, ApplyOp>) {
                        for (auto oup : cmd.outputs) {
                            oup->invalid = true;
                        }
                    } else if constexpr (std::is_same_v<T, Put>) {
                        cmd.dest->invalid = true;
                    }
                    m_worker_exc = std::current_exception();
                    MGB_RECORD_EVENT(WorkerExceptionEvent);
                    if (m_waitee) {
                        notify_tensor_unsafe(m_waitee);
                    }
                }
            },
            icmd.data);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
1266 1267
        // for reuse interpreter_for_py after some exception tests
        m_waitee = nullptr;
1268 1269
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
1270 1271 1272 1273 1274
        try {
            std::rethrow_exception(exc);
        } catch (...) {
            throw AsyncError();
        }
1275 1276
    }
}
1277

1278
void ChannelImpl::start_profile() {
1279
    MGB_LOCK_GUARD(m_spin);
1280
    mgb_assert(check_available(), "Channel already closed");
1281 1282
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
1283 1284 1285 1286 1287 1288 1289 1290 1291 1292
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StartProfile{std::move(capture_tensors)},
            });
        }
1293
    }
1294 1295
}

1296
void ChannelImpl::stop_profile() {
1297
    MGB_LOCK_GUARD(m_spin);
1298
    mgb_assert(check_available(), "Channel already closed");
1299 1300
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
1301 1302 1303 1304 1305 1306 1307 1308 1309 1310
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
                     get_channel_state().stack_manager.dump()});
        } else {
            m_worker.add_task({
                    Profiler::next_id(),
                    StopProfile{std::move(escape_tensors)},
            });
        }
1311
    }
1312 1313 1314
}

void ChannelImpl::push_scope(std::string name) {
1315
    MGB_LOCK_GUARD(m_spin);
1316
    mgb_assert(check_available(), "Channel already closed");
1317
    auto& state = get_channel_state();
1318
    state.stack_manager.enter(name);
1319
    MGB_RECORD_EVENT(ScopeEvent, name);
1320 1321 1322 1323 1324 1325 1326 1327 1328 1329
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PushScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PushScope{name},
        });
    }
1330 1331 1332
}

void ChannelImpl::pop_scope(std::string name) {
1333
    MGB_LOCK_GUARD(m_spin);
1334
    mgb_assert(check_available(), "Channel already closed");
1335
    auto& state = get_channel_state();
1336
    state.stack_manager.exit(name);
1337
    MGB_RECORD_EVENT(ScopeFinishEvent, name);
1338 1339 1340 1341 1342 1343 1344 1345 1346 1347
    if (Profiler::is_profiling()) {
        m_worker.add_task(
                {Profiler::next_id(), PopScope{name},
                 get_channel_state().stack_manager.dump()});
    } else {
        m_worker.add_task({
                Profiler::next_id(),
                PopScope{name},
        });
    }
1348 1349
}

1350
void ChannelImpl::assert_in_channel() {
M
Megvii Engine Team 已提交
1351 1352 1353
    mgb_assert(
            get_worker_tid() != std::this_thread::get_id(),
            "this method cannot be called in worker thread");
1354 1355 1356
}

void ChannelImpl::assert_in_worker() {
M
Megvii Engine Team 已提交
1357 1358 1359
    mgb_assert(
            get_worker_tid() == std::this_thread::get_id(),
            "this method can only be called in worker thread");
1360 1361
}

1362
// Record a memory-usage sample for `device` into the profiler stream.
// Unless `force` is set, samples are throttled by the per-thread counter and
// the "sample_rate" profiler option (0 disables sampling entirely).
void ChannelImpl::sample_on_device(CompNode device, bool force) {
    if (!Profiler::is_profiling()) {
        return;
    }
    if (!force) {
        thread_local int last_sample_id = 0;
        int sample_rate = Profiler::get_option("sample_rate", 0);
        if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
            return;
        }
    }
    MGB_RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}

1378 1379 1380
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
1381
        erase_candidate(i);
1382 1383 1384
    }
}

1385 1386
void ChannelImpl::DynamicSublinear::unpin(
        const SmallVector<TensorInfo*>& vec, WorkerState& state) {
1387 1388
    for (auto i : vec) {
        i->unpin();
1389 1390 1391 1392 1393
        if (i->pinned == 0 &&
            i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
            i->cand_index == UINT_MAX) {
            insert_candidate(i);
        }
1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439
    }
}

// After `ptr` has been recomputed, subtract its compute time from its DSU
// component's accumulated cost and split it back out as a singleton set
// (parent reset, own cost restored).
void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    dsu_fa->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

// After `ptr` is evicted by DROP, merge its DSU node with every dropped
// neighbor in its producer path (inputs merge into ptr, ptr merges into
// outputs), so chained recomputation costs accumulate in one component.
void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        // outputs may contain nullptr (discarded slots)
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

/* Estimate the recompute cost incurred by dropping `ptr`: over each droppable
 * neighbour (producer inputs and non-null sibling outputs marked DROP), sum
 * the larger of its DSU-component cost and its own compute time. */
double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    // cost contributed by one neighbour: max(component total, own compute time)
    auto component_cost = [this](TensorInfo* neighbor) -> double {
        double t = find_father(neighbor->dsu_ptr)->t;
        double own = neighbor->compute_time;
        return t < own ? own : t;
    };
    double cost = 0;
    for (auto* input : ptr->producer->inputs) {
        if (input->evict_type == EvictType::DROP) {
            cost += component_cost(input);
        }
    }
    for (auto* output : ptr->producer->outputs) {
        if (output && output->evict_type == EvictType::DROP) {
            cost += component_cost(output);
        }
    }
    return cost;
}

/* Pick the tensor to evict under the DTR policy: the candidate with the
 * smallest MSPS score as computed by TensorInfo::eval_func.
 *
 * When `enable_dtr_sqrt_sampling` is set, only roughly sqrt(|candidates|)
 * randomly-strided candidates are scored instead of scanning all of them.
 *
 * Returns nullptr when there is no candidate at all, or when no candidate is
 * currently evictable (must have a producer, live storage, and no pending
 * eviction).
 */
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
        bool enable_dtr_sqrt_sampling = false) {
    if (candidates.empty())
        return nullptr;

    double min_msps = -1;
    TensorInfo* best = nullptr;
    // sample size: floor(sqrt(#candidates)) when sampling, else everything
    size_t sz = 1;
    if (enable_dtr_sqrt_sampling) {
        while (sz * sz <= candidates.size())
            sz++;
        sz--;
    } else {
        sz = candidates.size();
    }

    size_t ti = rand() % sz;
    for (size_t vi = 0; vi < sz; vi++) {
        if (!enable_dtr_sqrt_sampling) {
            ti = vi;
        }
        auto i = candidates[ti];
        if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
            double neighbor_cost = estimate_neighbor_cost(i);
            // free memory adjacent to this tensor's storage block: a larger
            // contiguous hole after eviction makes it a better victim
            size_t begin_ptr =
                    reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(
                    begin_ptr, begin_ptr + i->ptr->blob()->size());
            double free_mem = side_info.first + side_info.second;
            double msps = i->eval_func(
                    neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
        if (enable_dtr_sqrt_sampling) {
            ti += rand() % sz;
            // BUGFIX: guard must be >= — with the previous `>` comparison,
            // ti == candidates.size() slipped through and the next iteration
            // read candidates[ti] one element past the end of the vector.
            if (ti >= candidates.size())
                break;
        }
    }
    return best;
}

/* Union the DSU components containing x and y, folding the component cost of
 * x's root into y's root. No-op when both are already in the same component. */
void ChannelImpl::DynamicSublinear::merge(
        std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
    auto root_x = find_father(x);
    auto root_y = find_father(y);
    if (root_x.get() == root_y.get())
        return;
    root_y->t += root_x->t;
    root_x->parent = root_y;
}

/* Return the root of x's DSU component, with full path compression: every
 * node visited on the way is repointed directly at the root. */
std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
        std::shared_ptr<DsuNode>& x) {
    if (x->is_root()) {
        return x;
    }
    // first pass: locate the root
    auto root = x;
    while (!root->is_root()) {
        root = root->parent;
    }
    // second pass: compress the path onto the root
    auto node = x;
    while (!node->is_root()) {
        auto next = node->parent;
        node->parent = root;
        node = next;
    }
    return root;
}

/* Register a brand-new eviction candidate at the back of `candidates`,
 * recording its slot in cand_index; the first candidate ever seen also
 * latches the policy's comp_node. */
void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    // a tensor being inserted must not already be tracked
    mgb_assert(
            ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
            ptr->cand_index);
    ptr->cand_index = candidates.size();
    candidates.emplace_back(ptr);
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

/* Remove a tensor from the candidate list by swapping it with the last
 * element, keeping every survivor's cand_index consistent. */
void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    // disabling DTR clears `candidates` wholesale; only reset the index then
    if (candidates.empty()) {
        ptr->cand_index = UINT_MAX;
        return;
    }
    // already erased tensors carry UINT_MAX and are skipped
    if (ptr->cand_index == UINT_MAX) {
        return;
    }
    size_t slot = ptr->cand_index;
    std::swap(candidates[slot], candidates.back());
    candidates[slot]->cand_index = slot;
    candidates.pop_back();
    ptr->cand_index = UINT_MAX;
}

// Stamp `ptr` with the policy's current estimated timestamp as its last-used
// time — presumably consumed by the eviction scoring (eval_func); confirm.
void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}