/**
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./interpreter_impl.h"

#include "megbrain/common.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/to_string.h"

#include "../op_trait.h"

using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}
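
// Illustrative usage sketch (not part of the original source; names like
// host_x and `op` are assumptions): a client obtains the singleton
// interpreter, creates a channel, and drives it through the Handle-based API
// defined below.
//
//     auto& interp = Interpreter::inst();
//     auto chan = interp.create_channel();
//     HostTensorND host_x, host_y;                 // assume these were filled elsewhere
//     auto hx = chan->put(host_x, /* no_cache */ false);
//     auto hy = chan->put(host_y, /* no_cache */ false);
//     auto outs = chan->apply_op(op, {hx, hy});    // `op` is some std::shared_ptr<OpDef>
//     HostTensorND result = chan->get_value(outs[0]);
//     for (auto h : outs) chan->del(h);
//     chan->del(hx);
//     chan->del(hy);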

Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    auto info = alloc();
    info->desc.layout = value.layout();
    info->desc.comp_node = value.comp_node();
    info->desc.value = value.proxy_to_default_cpu();
    info->h_value = value;
    m_buffer.enqueue(Put{info, value, no_cache});
    if (m_async_level == 0) {
        sync();
        info->desc.comp_node.sync();
    }
    return info;
}

Handle ChannelImpl::put(const DeviceTensorND& data) {
    auto info = alloc();
    info->desc.layout = data.layout();
    info->desc.comp_node = data.comp_node();
    info->ptr = Tensor::make(data);
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
    }
    return info;
}

void ChannelImpl::del(Handle handle) {
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    detach_users(info);
    info->detach_producer();
    m_valid_handle.erase(handle);
    m_buffer.enqueue(Del{info});
}

void ChannelImpl::swap_in(Handle handle) {
    if (m_worker_state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapIn{info});
        info->evict_type = NONE;
    }
}

void ChannelImpl::swap_out(Handle handle) {
    if (m_worker_state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapOut{info});
        info->evict_type = SWAP;
    }
}

void ChannelImpl::drop(Handle handle) {
    if (m_worker_state.options.enable_drop) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        if (!info->producer) {
            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", info);
            return;
        }
        info->evict_type = DROP;
        m_buffer.enqueue(Drop{info});
    }
}

void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_MARK_USED_VAR(validated);

    SmallVector<DeviceTensorND> input_tensornds;
    input_tensornds.reserve(input_descs.size());
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }

            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
            } else {
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    outputs->reserve(output_descs.size());
    SmallVector<DeviceTensorND> output_tensornds;
    output_tensornds.reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        // TODO: may conflict with CondTake, which needs to alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
        SmallVector<uint64_t> tid;
        for (auto* ptinfo: tinfo) {
            tid.push_back(ptinfo->id);
        }
        return tid;
    };
    OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data);
    }

    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);

    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
            .proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(info);
    }

    if (m_channel_state.options.enable_drop) {
        TensorInfo::ComputePath::make(op, input_infos, output_infos);
    }

    event_data.outputs = tinfo_to_tid(output_infos);
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data);
    }
}

void ChannelImpl::dispatch_kernel(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);

    ApplyOp cmd{std::move(op)};
    cmd.inputs = std::move(input_infos);
    cmd.outputs.reserve(output_descs.size());
    outputs->reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        auto info = alloc();
        info->desc = desc;
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
                .proxy_to_comp_node(desc.comp_node);
        }
        cmd.outputs.push_back(info);
        outputs->push_back(info);
    }
    if (m_channel_state.options.enable_drop) {
        TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
    }
    m_buffer.enqueue(std::move(cmd));
    if (!validated && m_channel_state.options.async_level == 1) {
        sync();
    } else if (m_channel_state.options.async_level == 0) {
        sync();
        // check device error
        for (auto&& oup : *outputs) {
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
        }
    }
}

SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op,
        const SmallVector<Handle>& inputs) {
    for (auto i : inputs) {
        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                "invalid handle: %p", i);
    }
    SmallVector<TensorInfo*> input_infos;
    input_infos.reserve(inputs.size());
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.reserve(inputs.size());
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
            regenerate(info);
        }
    }

    SmallVector<Handle> outputs;
    DispatchMode dispatch_mode = m_channel_state.options.enable_host_compute
            ? OpDef::decide_dispatch_mode(*op, input_descs)
            : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}
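
// Note on the dispatch above (summary added for readability): apply_op picks
// one of two paths. DEFAULT_CPU executes the op eagerly on the host via
// dispatch_default_cpu, and is only reachable when enable_host_compute is on
// and OpDef::decide_dispatch_mode allows it; KERNEL goes through
// dispatch_kernel, which merely enqueues an ApplyOp command, so the call
// normally returns before the kernel runs. The exceptions are async_level == 0
// (full sync plus a per-output device sync) and async_level == 1 (sync only
// when the output descs could not be fully inferred).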

HostTensorND ChannelImpl::get_value(Handle handle) {
    // TODO: maybe get_value should be done on host. i.e. delete GetValue
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    mgb_assert(!m_waitee);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    TensorPtr tensor_ptr = info->ptr;
    auto value_fetched = [&]() {
        return tensor_ptr && tensor_ptr->value_fetched();
    };
    if (!value_fetched()) {
        m_waitee = info;
        regenerate(info);
        m_buffer.enqueue(GetValue{info});
        if (m_channel_state.profiler->is_profiling()) {
            m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
        }
        m_cv.wait(lock, [&]() {
            check_worker_exc_unsafe();
            tensor_ptr = info->ptr;
            return value_fetched();
        });
        if (m_channel_state.profiler->is_profiling()) {
            m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
        }
        m_waitee = nullptr;
    }
    return tensor_ptr->get_value();
}

TensorShape ChannelImpl::get_shape(Handle handle) {
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee);
    m_waitee = info;
    m_buffer.flush();
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
    }
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return static_cast<bool>(info->ptr);
    });
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
    }
    m_waitee = nullptr;
    TensorShape ret = info->ptr->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

DType ChannelImpl::get_dtype(Handle handle) {
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
    }
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

CompNode ChannelImpl::get_device(Handle handle) {
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
    }
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee);
    m_waitee = info;
    regenerate(info);
    m_buffer.flush();
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
    }
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return static_cast<bool>(info->ptr);
    });
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
    }
    m_waitee = nullptr;
    return info->ptr->dev_tensor();
}

void ChannelImpl::sync() {
    m_buffer.flush();
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<SyncStartEvent>();
    }
    m_worker.wait_all_task_finish();
    CompNode::sync_all();
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<SyncFinishEvent>();
    }
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
    sync();
}

int ChannelImpl::get_option(std::string name) {
    return m_channel_state.options.get_option(name);
}

void ChannelImpl::set_option(std::string name, int value) {
    m_channel_state.options.set_option(name, value);
    m_buffer.enqueue(SetOption{name, value});
}

TensorInfo* ChannelImpl::alloc() {
    MGB_LOCK_GUARD(m_mutex);
    auto info = m_pool.alloc();
    m_valid_handle.insert(info);
    info->id = m_last_id++;
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id);
    }
    return info;
}

void ChannelImpl::free(TensorInfo* ptr) {
    MGB_LOCK_GUARD(m_mutex);
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id);
    }
    m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this) {
    m_channel_state.tid = std::this_thread::get_id();
}

ChannelImpl::~ChannelImpl() {
    close();
}

void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    MGB_LOCK_GUARD(m_mutex);
    if (m_worker_state.profiler->is_profiling()) {
        m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
    }
    dest->value_fetched = ptr->value_fetched();
    // update tensor desc for static infer
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->ptr = std::move(ptr);
    if (m_waitee == dest) {
        m_cv.notify_all();
    }
}

void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
}

void ChannelImpl::regenerate(TensorInfo* dest) {
    if (dest->evict_type == DROP) {
        recompute(dest->producer);
    } else if (dest->evict_type == SWAP) {
        swap_in(dest);
    }
    mgb_assert(dest->evict_type == NONE);
}
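
// Added note: regenerate() is the single entry point for reviving an evicted
// tensor right before it is read (see apply_op, get_value and get_dev_tensor).
// A DROP-ed tensor is rebuilt by re-enqueueing its producer ComputePath via
// recompute(); a SWAP-ed tensor is restored from the host copy saved by the
// SwapOut command via swap_in(). A rough timeline, assuming enable_drop is set:
//
//     drop(h);            // worker eventually releases the device storage of h
//     ...
//     get_dev_tensor(h);  // regenerate() -> recompute(producer) re-enqueues the
//                         // producing ApplyOp before the read is answered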

void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
    SmallVector<TensorInfo*> workspaces(path->outputs.size(), nullptr);
    for (auto&& input: path->inputs) {
        regenerate(input);
    }
    for (auto&& output: path->outputs) {
        if(output == nullptr) {
            continue;
        }
        output->evict_type = NONE;
    }
    m_buffer.enqueue(ApplyOp{path->op, path->inputs, path->outputs});
}

void ChannelImpl::detach_users(TensorInfo* dest) {
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user: users) {
        for (auto* output: user->outputs) {
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
        }
    }
    mgb_assert(dest->users.size() == 0);
    //dest->users.clear();
}

void ChannelImpl::sync_device_scope(CompNode device) {
    auto& prev = m_worker_state.device_scope_map[device];
    auto& current = m_worker_state.scopes;
    auto push_scope = [&](std::string name) {
        m_worker_state.profiler->record_device<DeviceBeginScope>(device, name);
    };
    auto pop_scope = [&](std::string name) {
        m_worker_state.profiler->record_device<DeviceEndScope>(device, name);
    };
    size_t similarity = 0;
    for (size_t i = 0; i < prev.size() && i < current.size(); i++) {
        if (prev[i] == current[i]) {
            similarity++;
        } else {
            break;
        }
    }
    while (prev.size() > similarity) {
        pop_scope(prev.back());
        prev.pop_back();
    }
    while (prev.size() < current.size()) {
        prev.push_back(current[prev.size()]);
        push_scope(prev.back());
    }
}
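
// Worked example for the scope reconciliation above (illustrative names): if
// the scopes already opened on the device are ["step", "fwd"] and the worker's
// current stack is ["step", "bwd"], the common prefix has length 1, so a
// DeviceEndScope("fwd") is recorded followed by a DeviceBeginScope("bwd"),
// after which both stacks equal ["step", "bwd"].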

void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
    if (m_worker_state.profiler->is_profiling()) {
        m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd);
    }
    bool finished = false;
    auto do_finish_command = [&]{
        if (finished) {
            return;
        }
        if (m_worker_state.profiler->is_profiling()) {
            m_worker_state.profiler->record_host<CommandFinishEvent>(icmd);
        }
        finished = true;
    };
    // TODO: remove std::visit to support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, Put>) {
                auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
                produce_tensor(cmd.dest, std::move(value));
            } else if constexpr (std::is_same_v<T, ApplyOp>) {
                uint64_t apply_id = ++m_last_id;
                SmallVector<TensorPtr> tensor_inputs;
                SmallVector<CompNode> devices;
                tensor_inputs.reserve(cmd.inputs.size());
                // refcnt == 1, owners: [TensorInfo::ptr]
                for (auto i : cmd.inputs) {
                    mgb_assert(i->ptr, "Invalid input tensor ptr!");
                    // refcnt ++, owners: [i->ptr, tensor_inputs]
                    tensor_inputs.push_back(i->ptr);
                }
                // Begin profiling operator
                OpEvent event_data;
                if (m_worker_state.profiler->is_profiling()) {
                    auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
                        SmallVector<uint64_t> tid;
                        for (auto* ptinfo: tinfo) {
                            tid.push_back(ptinfo->id);
                        }
                        return tid;
                    };
                    event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
                    // Collecting devices
                    for (auto i : cmd.inputs) {
                        devices.push_back(i->desc.comp_node);
                    }
                    for (auto i : cmd.outputs) {
                        devices.push_back(i->desc.comp_node);
                    }
                    devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
                }
                // Fused by command buffer. @see: CommandBuffer::fuse_del
                // Now if dest is inplaceable, its refcnt would be decreased to 1 and owned by tensor_inputs after Del.
                // Note: for exprs like 'y = x op x', inplace is not supported yet, but the Del would also be fused.
                for (auto* del : cmd.dels) {
                    // refcnt --, owners: [tensor_inputs]
                    // if it's decreased to 1, it would be detected at @see: proxy_graph_detail::apply_on_physical_tensor
                    free(del);
                }
                // Before wait
                // TODO: split operator wait and execute so that OpWait could be correctly recorded.
                // Before execute
                if (m_worker_state.profiler->is_profiling()) {
                    m_worker_state.profiler->record_host<HostOpExecuteEvent>(event_data);
                    for (auto&& device: devices) {
                        sync_device_scope(device);
                        m_worker_state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data);
                    }
                }
                // Apply op
                // Here std::move is REQUIRED for removing duplicated references.
                auto tensor_outputs = OpDef::apply_on_physical_tensor(
                    *cmd.op, std::move(tensor_inputs));
                // After execute
                if (m_worker_state.profiler->is_profiling()) {
                    m_worker_state.profiler->record_host<HostOpFinishEvent>(event_data);
                    for (auto&& device: devices) {
                        m_worker_state.profiler->record_device<DeviceOpFinishEvent>(device, event_data);
                    }
                }
                // End profiling operator
                mgb_assert(tensor_outputs.size() == cmd.outputs.size());
                for (size_t i = 0; i < tensor_outputs.size(); ++i) {
                    if (cmd.outputs[i] == nullptr) {
                        continue;
                    }
                    produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
                }
            } else if constexpr (std::is_same_v<T, Del>) {
                free(cmd.dest);
            } else if constexpr (std::is_same_v<T, GetValue>) {
                mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
                cmd.dest->ptr->fetch_value();
                MGB_LOCK_GUARD(m_mutex);
                cmd.dest->value_fetched = true;
                if (m_waitee == cmd.dest) {
                    m_cv.notify_all();
                }
            } else if constexpr (std::is_same_v<T, SwapIn>) {
                produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
            } else if constexpr (std::is_same_v<T, SwapOut>) {
                cmd.dest->h_value = cmd.dest->ptr->get_value();
                release_tensor(cmd.dest);
            } else if constexpr (std::is_same_v<T, Drop>) {
                release_tensor(cmd.dest);
            } else if constexpr (std::is_same_v<T, SetOption>) {
                m_worker_state.options.set_option(cmd.key, cmd.value);
            } else if constexpr (std::is_same_v<T, StartProfile>) {
                CompNode::sync_all();
                m_worker_state.profiler.reset(cmd.profiler);
            } else if constexpr (std::is_same_v<T, StopProfile>) {
                for (auto&& [device, scopes]: m_worker_state.device_scope_map) {
                    MGB_MARK_USED_VAR(scopes);
                    sync_device_scope(device);
                }
                do_finish_command();
                auto profiler = std::make_unique<InterpreterProfiler>();
                std::swap(profiler, m_worker_state.profiler);
                auto records = profiler->stop();
                auto host_map = [this](std::thread::id tid) {
                    if (tid == m_channel_state.tid) {
                        return "channel";
                    } else if (tid == m_worker_state.tid) {
                        return "worker";
                    } else {
                        return "unknown";
                    }
                };
                InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map);
            } else if constexpr (std::is_same_v<T, PushScope>) {
                m_worker_state.scopes.push_back(cmd.scope_name);
                do_finish_command();
                m_worker_state.profiler->record_host<WorkerBeginScope>(cmd.scope_name);
            } else if constexpr (std::is_same_v<T, PopScope>) {
                mgb_assert(m_worker_state.scopes.back() == cmd.scope_name, "scope name mismatch");
                m_worker_state.scopes.pop_back();
                do_finish_command();
                m_worker_state.profiler->record_host<WorkerEndScope>(cmd.scope_name);
            } else {
                static_assert(!std::is_same_v<T, T>);
            }
    };
    std::visit([&](const auto& cmd){
        using T = std::decay_t<decltype(cmd)>;
        if (!m_worker_state.options.catch_worker_execption) {
            cmd_visitor(cmd);
            return;
        }
        try {
            cmd_visitor(cmd);
        } catch (...) {
            MGB_LOCK_GUARD(m_mutex);
            if constexpr (std::is_same_v<T, ApplyOp>) {
                for (auto oup : cmd.outputs) {
                    oup->invalid = true;
                }
            } else if constexpr (std::is_same_v<T, Put>) {
                cmd.dest->invalid = true;
            }
            m_worker_exc = std::current_exception();
            m_cv.notify_all();
        }
    }, icmd.second);
    do_finish_command();
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
        // allow reusing interpreter_for_py after some exception tests
        m_waitee = nullptr;
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
        std::rethrow_exception(exc);
    }
}

void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
        return;
    }
    // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
    m_commands.push_back(std::move(cmd));
    auto flush_pos = flush_pos_for(m_commands.back());
    flush(flush_pos);
}

void ChannelImpl::CommandBuffer::flush() {
    flush(m_commands.end());
}

void ChannelImpl::CommandBuffer::flush(Handle pos) {
    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
        // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
        IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
        if (m_owner->m_channel_state.profiler->is_profiling()) {
            m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd);
        }
        m_owner->m_worker.add_task(std::move(icmd));
    }
    m_commands.erase(m_commands.begin(), pos);
}

auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
    return std::visit([this](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, ApplyOp>) {
            auto* op_type = cmd.op->dyn_typeinfo();
            if (op_type == RemoteRecv::typeinfo() ||
                op_type == RemoteSend::typeinfo() ||
                op_type == CollectiveComm::typeinfo() ||
                op_type == opr::InputCallback::typeinfo() ||
                op_type == opr::OutputCallback::typeinfo() ||
                op_type == BackwardGraph::typeinfo()) {
                return m_commands.end();
            }
        } else if constexpr (std::is_same_v<T, GetValue>) {
            return m_commands.end();
        }
        size_t buffer_length = m_owner->m_channel_state.options.buffer_length;
        if (m_commands.size() > buffer_length) {
            return m_commands.begin() + (m_commands.size() - buffer_length);
        }
        return m_commands.begin();
    }, cmd);
}

/**
 * 1. Find the ApplyOp that consumes dest among the buffered commands
 * 2. If dest has any other usage after that ApplyOp, return false (cannot fuse)
 * 3. Otherwise fuse the Del into that ApplyOp and return true
 */
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
    auto* dest = cmd.dest;
    // TODO: eliminate Puts
    auto begin = m_commands.begin(), end = m_commands.end();
    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
        if (auto* apply = std::get_if<ApplyOp>(&cmd)) {
            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
        }
        return false;
    });
    if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
        return false;
    }
    // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
    std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
    return true;
}
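
// Illustrative example of the fusion above (hypothetical buffer contents):
// when Del{x} is enqueued while m_commands holds
//
//     [ ApplyOp{op, inputs={x, y}, outputs={z}} ]
//
// and x has no other usage after that ApplyOp, the Del is folded into
// ApplyOp::dels instead of being buffered. The worker then frees x right after
// its TensorPtr has been captured into tensor_inputs, which enables in-place
// reuse of x's storage when the kernel supports it (see the refcnt comments in
// process_one_task).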

auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
        -> Handle {
    auto found = range[1];
    for (auto iter = range[0]; iter != range[1]; ++iter) {
        std::visit([&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
                               dest) > 0) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (cmd.dest == dest) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, SwapIn> ||
                    std::is_same_v<T, SwapOut> ||
                    std::is_same_v<T, Drop>) {
                //TODO: ignore swap-like commands, just remove them from buffer
                if (cmd.dest == dest) {
                    found = iter;
                }
            }
        }, *iter);
    }
    return found;
}

auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
        -> Handle {
    return std::find_if(range[0], range[1], [dest](auto& cmd) {
        return std::visit([dest](const auto& cmd){
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
            } else if constexpr (std::is_same_v<T, Put>) {
                return cmd.dest == dest;
            }
            return false;
        }, cmd);
    });
}

void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
    auto profiler_option = InterpreterProfiler::Option::from_dict(option);
    auto profiler = std::make_unique<InterpreterProfiler>();
    profiler->set_option(profiler_option);
    profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic));
    std::swap(profiler, m_channel_state.profiler);
    m_buffer.enqueue(StartProfile{m_channel_state.profiler.get()});
}

void ChannelImpl::stop_profile(std::string basename, std::string format) {
    m_buffer.flush();
    auto profiler = std::make_unique<InterpreterProfiler>();
    std::swap(profiler, m_channel_state.profiler);
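    // Presumably safe: the raw pointer was handed to the worker in
    // start_profile via StartProfile{...}, and the worker takes ownership in
    // its StartProfile handler (m_worker_state.profiler.reset(cmd.profiler)),
    // so release() only drops the channel-side reference to avoid a double free.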
    profiler.release();
    m_buffer.enqueue(StopProfile{basename, format});
}
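
// Hedged usage sketch for the profiling API above (the empty option map and
// the basename/format strings are illustrative, not taken from this file):
//
//     channel->start_profile({});              // options parsed by InterpreterProfiler::Option::from_dict
//     // ... run some ops ...
//     channel->stop_profile("trace", "json");  // worker dumps via InterpreterProfiler::dump_data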

void ChannelImpl::push_scope(std::string name) {
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<ChannelBeginScope>(name);
        m_channel_state.scopes.push_back(name);
        m_buffer.enqueue(PushScope{name});
    }
}

void ChannelImpl::pop_scope(std::string name) {
    if (m_channel_state.profiler->is_profiling()) {
        mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
        m_channel_state.scopes.pop_back();
        m_channel_state.profiler->record_host<ChannelEndScope>(name);
        m_buffer.enqueue(PopScope{name});
    }
}
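
// Scopes only take effect while profiling is active; a minimal pairing sketch
// (the scope name is illustrative):
//
//     channel->push_scope("forward");
//     // ... apply_op calls attributed to this scope ...
//     channel->pop_scope("forward");   // must match the innermost push_scope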

void ChannelImpl::assert_in_channel() {
    mgb_assert(m_channel_state.tid != std::this_thread::get_id());
}

void ChannelImpl::assert_in_worker() {
    mgb_assert(m_worker_state.tid == std::this_thread::get_id());
}