/**
 * \file imperative/python/src/graph_rt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./graph_rt.h"

#include <algorithm>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "megbrain/graph/cg.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative.h"
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"
#include "./ops.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/imperative/profiler_plugin.h"

namespace py = pybind11;

using namespace mgb;
using namespace imperative;
32
namespace ser = mgb::serialization;
33

34 35
using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
36
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
37
using _SerializationMetadata = mgb::serialization::Metadata;
38

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
namespace {
class _CompGraphProfilerImpl {
    std::shared_ptr<ComputingGraph> m_comp_graph;
    GraphProfiler m_profiler;
    public:
        _CompGraphProfilerImpl(std::shared_ptr<ComputingGraph> cg):
            m_comp_graph{cg},
            m_profiler{m_comp_graph.get()}
        {
        }

        std::string _get_result() {
            auto json = m_profiler.to_json_full(
                    m_comp_graph->current_comp_seq());
            return json->to_string();
        }
};
56 57 58 59 60 61 62

struct WeakRendezvousArray:
    public std::vector<std::weak_ptr<RendezvousBase>>,
    public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;
};
MGB_TYPEINFO_OBJ_IMPL(WeakRendezvousArray);
63
}
64 65 66 67 68
//! Binds data member ``name`` of CURRENT_CLASS as a read/write Python attribute.
//! CURRENT_CLASS must be #defined at the point of use.
#define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name)

/*!
 * \brief Register ``Rendezvous<T>`` (held by shared_ptr) as Python class \p name.
 *
 * Python methods: ``set(v)``, ``get()``, ``drop()``, ``reset()`` and
 * ``set_exception(message)``; the latter stores a std::runtime_error carrying
 * \p message.
 */
template <typename T>
auto def_rendezvous(py::object m, const char* name) {
    return py::class_<Rendezvous<T>, std::shared_ptr<Rendezvous<T>>>(m, name)
            .def(py::init([]() { return Rendezvous<T>::make(); }))
            .def("set", [](Rendezvous<T>& r, T v) { r.set(std::move(v)); })
            // The GIL is released while get() runs; presumably get() can block
            // until a value is set from another thread — TODO confirm.
            .def("get", [](Rendezvous<T>& r) { return r.get(); },
                 py::call_guard<py::gil_scoped_release>())
            .def("drop", &Rendezvous<T>::drop)
            .def("reset", &Rendezvous<T>::reset)
            .def("set_exception", [](Rendezvous<T>& r, std::string&& message) {
                r.set_exception(std::make_exception_ptr(
                        std::runtime_error(std::move(message))));
            });
}

using TensorAttr = LogicalTensorDesc;
M
Megvii Engine Team 已提交
81
using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
82

83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
/*!
 * \brief Rebuild \p vars with each var in \p repl_src replaced by the var at
 * the same index in \p repl_dst, using mgb::cg::replace_vars.
 *
 * \param repl_src vars to be replaced
 * \param repl_dst replacements; must have the same length as \p repl_src
 * \param vars endpoint vars to rewrite
 * \return the rewritten counterparts of \p vars, in the same order
 */
std::vector<mgb::cg::VarNode*> _replace_vars(
        const std::vector<mgb::cg::VarNode*>& repl_src,
        const std::vector<mgb::cg::VarNode*>& repl_dst,
        const std::vector<mgb::cg::VarNode*>& vars) {
    // repl_dst is indexed by repl_src's positions; a shorter repl_dst would be
    // read out of bounds.
    mgb_assert(repl_src.size() == repl_dst.size(),
               "_replace_vars: got %zu source vars but %zu replacements",
               repl_src.size(), repl_dst.size());
    mgb::ThinHashMap<SymbolVar, SymbolVar> varmap;
    for (size_t i = 0; i < repl_src.size(); ++i) {
        varmap[SymbolVar(repl_src[i])] = SymbolVar(repl_dst[i]);
    }
    SymbolVarArray symvars(vars.begin(), vars.end());
    auto sym_result = mgb::cg::replace_vars(symvars, varmap);
    std::vector<mgb::cg::VarNode*> result;
    result.reserve(sym_result.size());
    for (auto&& symvar : sym_result) {
        result.push_back(symvar.node());
    }
    return result;
}

typedef std::vector<mgb::cg::OperatorNodeBase*> OperatorArray;
std::vector<mgb::cg::VarNode*> _replace_oprs(const OperatorArray& repl_src,
                                 const OperatorArray& repl_dst,
                                 const std::vector<mgb::cg::VarNode*>& vars) {
        mgb::ThinHashMap<mgb::cg::OperatorNodeBase*, mgb::cg::OperatorNodeBase*>
                oprmap;
        for (size_t i = 0; i < repl_src.size(); ++i) {
            oprmap[repl_src[i]] = repl_dst[i];
        }
        const SymbolVarArray symvars(vars.begin(), vars.end());
        auto sym_result = mgb::cg::replace_oprs(symvars, oprmap);
        std::vector<mgb::cg::VarNode*> result;
        for (auto symvar : sym_result){
            result.push_back(symvar.node());
        }
        return result;
    }



/*!
 * \brief For every operator visited by a DepOprIter seeded with \p dest_vars,
 * set its priority to its operator id if the priority is currently 0.
 *
 * Operators that already have a non-zero priority are left untouched.
 */
void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
    mgb::cg::DepOprIter iter{[](mgb::cg::OperatorNodeBase* op) {
        auto& prio = op->node_prop().attribute().priority;
        if (!prio) {
            prio = op->id();
        }
    }};
    for (auto* endpoint : dest_vars) {
        iter.add(SymbolVar(endpoint));
    }
}



133
void init_graph_rt(py::module m) {
134 135 136

   static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{std::make_unique<mgb::OprFootprint>()};

137 138
    def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");

M
Megvii Engine Team 已提交
139 140
    def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");

141 142 143 144 145
    def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");

    py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
        .def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
        .def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
146 147
        .def_property("name", py::overload_cast<>(&VarNode::name, py::const_),
                      py::overload_cast<std::string>(&VarNode::name))
148
        .def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
M
Megvii Engine Team 已提交
149 150 151 152
        .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
        .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
                auto&& mgr = v->owner_graph()->static_infer_manager();
                return mgr.infer_shape_fallible(v);
153 154 155 156 157 158 159 160 161 162 163 164 165
            })
        .def_property_readonly("value", [](cg::VarNode* v) -> py::object {
                auto&& mgr = v->owner_graph()->static_infer_manager();
                auto&& type = mgr.get_infer_type(v);
                using InferType = cg::static_infer::InferType;
                if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
                    return py::none();
                }
                auto* val = mgr.infer_value_fallible(v);
                if (!val) {
                    return py::none();
                }
                return py::cast(*val).attr("numpy")();
166 167 168
            })
        .def_property_readonly("id",[](cg::VarNode* v){
            return (v->id());
169 170 171
        })
        .def("__repr__", [](cg::VarNode* v) {
            return "Var:" + v->name();
172
        });
173 174 175

    py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
        .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
176 177
        .def_property("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_),
                      py::overload_cast<std::string>(&cg::OperatorNodeBase::name))
178 179 180 181
        .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
                return to_tuple(opr->input());
            })
        .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
M
Megvii Engine Team 已提交
182
                return to_tuple(opr->usable_output());
183 184 185 186 187 188 189 190 191
            })
        .def_property_readonly("id",[](cg::OperatorNodeBase* opr){
            return opr->id();
        })
        .def_property_readonly("params",[](cg::OperatorNodeBase* opr){
            return _imperative_sm_opr_footprint_ptr->calc_footprint(opr).param->to_string();
        })
        .def_property_readonly("type",[](cg::OperatorNodeBase* opr){
            return opr->dyn_typeinfo()->name;
192 193 194
        })
        .def("__repr__", [](cg::OperatorNodeBase* opr){
            return "Opr:" + opr->name();
195 196
        });

197 198
    py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
        .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())
199
        .def("wait", &cg::AsyncExecutable::wait, py::call_guard<py::gil_scoped_release>())
200
        .def("get_prev_exec_time", &cg::AsyncExecutable::get_prev_exec_time, py::call_guard<py::gil_scoped_release>())
201 202 203 204 205
        .def("_to_json", [](cg::AsyncExecutable* exec) {
            py::call_guard<py::gil_scoped_release>();
            // dump currently compiled computing graph for debugging
            return exec->to_json()->to_string();
        })
206 207 208 209 210 211 212 213 214 215 216 217 218
        // only used for exception handle
        .def_property_readonly("_all_rendezvous", [](cg::AsyncExecutable* exec) {
            auto ud = exec->owner_graph()->options().user_data
                        .get_user_data<WeakRendezvousArray>();
            std::vector<std::shared_ptr<RendezvousBase>> ret;
            if (ud.second) {
                for (auto&& r: *ud.first[0]) {
                    if (auto p = r.lock()) {
                        ret.emplace_back(std::move(p));
                    }
                }
            }
            return ret;
219 220 221 222
        })
        .def("get_static_memory_alloc_info",
             &cg::AsyncExecutable::get_static_memory_alloc_info,
             py::call_guard<py::gil_scoped_release>());
223 224 225 226 227 228 229 230 231 232 233 234 235

    auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph")
        .def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))
        .def("compile", [](cg::ComputingGraph& graph, const std::vector<cg::VarNode*>& dest_vars) {
                mgb_assert(!dest_vars.empty());
                cg::ComputingGraph::OutputSpec spec;
                for (auto v : dest_vars) {
                    spec.emplace_back(v, nullptr);
                }
                return graph.compile(spec);
            })
        .def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options));

236 237 238 239 240 241
    py::class_<_CompGraphProfilerImpl, std::shared_ptr<_CompGraphProfilerImpl>>(m, "GraphProfiler")
        .def(py::init([](std::shared_ptr<ComputingGraph> graph) {
                return std::make_shared<_CompGraphProfilerImpl>(graph);
                }))
        .def("get", [](_CompGraphProfilerImpl& profiler) { return profiler._get_result(); });

242 243 244 245
    using interpreter::intl::ProfilerPlugin;
    py::class_<ProfilerPlugin, std::shared_ptr<ProfilerPlugin>>(m, "GraphProfiler2")
        .def(py::init<cg::ComputingGraph*>());

246 247
    auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
        .def(py::init())
248 249
        .def("serialize", &_OptimizeForInferenceOptions::serialize)
        .def_static("deserialize", &_OptimizeForInferenceOptions::deserialize)
250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265
        .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
        .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
        .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
        .def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
        .def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform)
        ;

    py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
        .value("DEFAULT", _LayoutTransform::DEFAULT)
        .value("NCHW4", _LayoutTransform::NCHW4)
        .value("NHWCD4", _LayoutTransform::NHWCD4)
        .value("NCHW88", _LayoutTransform::NCHW88)
        .value("NCHW44", _LayoutTransform::NCHW44)
        .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
        .value("NCHW32", _LayoutTransform::NCHW32)
        .value("CHWN4", _LayoutTransform::CHWN4)
266
        .value("NCHW64", _LayoutTransform::NCHW64)
267 268 269 270 271 272 273 274 275 276 277 278
        .export_values()
        ;

    m.def("optimize_for_inference", [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
        SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
        auto res_symvars = mgb::gopt::optimize_for_inference(symvars, opt);
        VarNodeArray vars;
        for (auto& si: res_symvars)
            vars.push_back(si.node());
        return vars;
    });

279
    m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars,
280 281
                                                 const _AlgoStrategy& strategy) {
        mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, strategy);
282 283
    });

284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314
    m.def("get_info_for_strip", [](const std::vector<VarNode*>& dest_vars) {
        std::unordered_set<const char*> opr_types, dtype_names, elemwise_modes;
        auto on_opr = [&](cg::OperatorNodeBase *opr) {
            if (ser::GraphDumper::should_remove_in_dump(opr))
                return;
            opr_types.insert(opr->dyn_typeinfo()->name);
            for (auto i : opr->output())
                dtype_names.insert(i->dtype().name());
            if (opr->same_type<opr::Elemwise>()) {
                auto mode = opr->cast_final<opr::Elemwise>().param().mode;
                elemwise_modes.insert(
                        megdnn::Elemwise::ModeTrait::from_mode(mode).name);
            }
        };
        cg::DepOprIter opr_iter{on_opr};
        for (auto i : dest_vars)
            opr_iter.add(i->owner_opr());

        auto to_json = [](const std::unordered_set<const char*> &v) {
            std::vector<std::string> vs(v.begin(), v.end());
            std::sort(vs.begin(), vs.end());
            auto ret = json::Array::make();
            for (auto &&i : vs)
                ret->add(json::String::make(i));
            return ret;
        };

        return json::Object::make({
            {"opr_types", to_json(opr_types)},
            {"dtypes", to_json(dtype_names)},
            {"elemwise_modes", to_json(elemwise_modes)},
315
        })->to_string();
316 317
    });

318 319 320 321 322 323 324 325 326 327 328
    py::class_<_SerializationMetadata>(m, "SerializationMetadata")
        .def(py::init())
        .def_property("user_info", [](const _SerializationMetadata& meta){return py::bytes(meta.get_user_info()); },
            &_SerializationMetadata::set_user_info)
        .def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference)
        .def_property("optimize_options", &_SerializationMetadata::get_optimize_options,
            &_SerializationMetadata::set_optimize_options)
        .def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
        .def_readwrite("is_valid", &_SerializationMetadata::is_valid)
        ;

329 330 331
    m.def("dump_graph", [](
        const std::vector<VarNode*>& dest_vars,
        int keep_var_name,
332
        bool keep_opr_name,
333 334
        bool keep_param_name,
        bool keep_opr_priority,
335
        std::optional<_SerializationMetadata> metadata,
336 337 338 339 340 341 342 343 344 345
        py::list& stat,
        py::list& inputs,
        py::list& outputs,
        py::list& params
    ) {
        std::vector<uint8_t> buf;
        auto dumper = ser::GraphDumper::make(ser::OutputFile::make_vector_proxy(&buf));
        SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());

        ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
346
                                       keep_opr_priority, keep_opr_name};
347

348 349 350 351 352 353
        ser::GraphDumper::DumpResult rst;
        if (metadata)
            rst = dumper->dump(symvars, config, *metadata);
        else
            rst = dumper->dump(symvars, config);

354 355 356 357 358 359 360 361 362
        for (auto i : rst.inputs) {
            inputs.append(py::cast(i));
        }
        for (auto i : rst.outputs) {
            outputs.append(py::cast(i));
        }
        for (auto i : rst.params) {
            params.append(py::cast(i));
        }
363 364 365
        auto rst_stat =
                std::vector{rst.nr_opr, rst.tot_bytes, rst.tensor_value_bytes,
                            static_cast<size_t>(rst.content_hash)};
366 367 368 369 370
        for (auto i : rst_stat) {
            stat.append(py::cast(i));
        }
        return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
    });
371

372 373 374 375 376 377 378 379 380
    m.def("load_graph", [](
        std::string& buf,
        py::list& output_var_map,
        py::list& output_var_list
    ) {
        auto file = ser::InputFile::make_mem_proxy(buf.c_str(), buf.length());
        auto format = ser::GraphLoader::identify_graph_dump_format(*file);
        auto loader = ser::GraphLoader::make(std::move(file), format.val());
        ser::GraphLoader::LoadConfig config;
381
        auto rst = loader->load(config);
382 383
        for (auto i : rst.output_var_map) {
            output_var_map.append(py::make_tuple(i.first, i.second.node()));
384
        }
385 386
        for (auto i : rst.output_var_list) {
            output_var_list.append(i.node());
387 388 389 390 391 392 393 394 395 396 397 398 399 400 401
        }
        std::unordered_map<HostTensorND*, const std::string*> tensor2name;
        for (const auto& pair : rst.tensor_map) {
            tensor2name[pair.second.get()] = &pair.first;
        }
        auto cb = [&tensor2name, graph=rst.graph](cg::OperatorNodeBase* opr) {
            if (!opr->same_type<opr::Host2DeviceCopy>())
                return;
            auto& h2d = opr->cast_final_safe<opr::Host2DeviceCopy>();
            auto it = tensor2name.find(h2d.host_data().get());
            mgb_throw_if(it == tensor2name.end(), GraphError,
                        "unbound Host2DeviceCopy in loaded graph");
            h2d.output(0)->name(*it->second);
        };
        cg::DepOprIter iter{cb};
402 403
        for (const auto& var : rst.output_var_list) {
            iter.add(var);
404
        }
405 406 407 408
        auto ret = py::tuple(2);
        ret[0] = py::cast(rst.graph);
        ret[1] = py::cast(rst.metadata);
        return ret;
409 410
    });

411 412 413 414 415 416 417 418 419 420 421 422 423 424
#define CURRENT_CLASS cg::ComputingGraph::Options

    auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
        // DEF_READWRITE(opr_attribute)
        DEF_READWRITE(seq_opt)
        DEF_READWRITE(graph_opt)
        DEF_READWRITE(graph_opt_level)
        DEF_READWRITE(log_level)
        DEF_READWRITE(async_exec_level)
        DEF_READWRITE(force_dynamic_alloc)
        DEF_READWRITE(var_sanity_check_first_run)
        DEF_READWRITE(allocate_static_mem_after_graph_compile)
        DEF_READWRITE(fake_next_exec)
        DEF_READWRITE(enable_sublinear_memory_opt)
425
        DEF_READWRITE(enable_dtr_memory_opt)
426 427 428 429 430
        DEF_READWRITE(no_profiling_on_shape_change)
        DEF_READWRITE(enable_var_mem_defragment)
        DEF_READWRITE(enable_grad_var_static_reshape)
        DEF_READWRITE(enable_memory_swap)
        DEF_READWRITE(comp_node_seq_record_level)
431
        DEF_READWRITE(no_force_inplace)
432
        DEF_READWRITE(sublinear_mem_config)
433
        DEF_READWRITE(dtr_config)
434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450
        // DEF_READWRITE(eager_evaluation)
        // DEF_READWRITE(imperative_proxy_graph)
        // DEF_READWRITE(extra_vardeps)
        // DEF_READWRITE(user_data)
        ;

#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::SeqOpt

    py::class_<cg::ComputingGraph::Options::SeqOpt>(PyComputingGraphOptions, "SeqOpt")
        DEF_READWRITE(enable_mem_plan_opt)
        DEF_READWRITE(enable_mem_reuse_alloc)
        DEF_READWRITE(enable_seq_comp_node_opt);

#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt

451 452
    auto PyGraphOpt = py::class_<cg::ComputingGraph::Options::GraphOpt>(
            PyComputingGraphOptions, "GraphOpt")
453
        DEF_READWRITE(jit)
454
        DEF_READWRITE(jit_config)
455 456 457
        DEF_READWRITE(tensorrt);

#undef CURRENT_CLASS
458
#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt::JITConfig
459

460 461 462 463 464
    py::class_<cg::ComputingGraph::Options::GraphOpt::JITConfig>(PyGraphOpt, "JITConfig")
        DEF_READWRITE(fuse_dimshuffle)
        DEF_READWRITE(fuse_reduce);

#undef CURRENT_CLASS
465 466 467 468 469 470 471 472 473
#define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig

    py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
        DEF_READWRITE(thresh_nr_try)
        DEF_READWRITE(genetic_nr_iter)
        DEF_READWRITE(genetic_pool_size)
        DEF_READWRITE(lb_memory)
        DEF_READWRITE(num_worker);

474 475 476 477 478 479 480 481
#undef CURRENT_CLASS

#define CURRENT_CLASS cg::ComputingGraph::Options::DTRConfig

    py::class_<cg::ComputingGraph::Options::DTRConfig>(PyComputingGraphOptions, "DTRConfig")
        DEF_READWRITE(eviction_threshold)
        DEF_READWRITE(evictee_minimum_size);

482
#undef CURRENT_CLASS
483 484 485 486
    auto common = rel_import("common", m, 1);

    common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
            cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
487
            return to_tuple(OpDef::apply_on_var_node(def, vinputs));
488 489 490 491 492 493
        },
        py::arg(), py::arg(), py::arg("graph") = py::none());

    auto input_callback = [](auto callback,
                             const CompNode& comp_node,
                             const DType& dtype,
M
Megvii Engine Team 已提交
494
                             const TensorShape& shape,
495
                             const std::vector<cg::VarNode*>& inputs,
496 497
                             cg::ComputingGraph* graph,
                             bool use_static_shape) {
498 499 500 501 502 503 504 505
        if (!graph) {
            graph = inputs[0]->owner_graph();
        }
        SymbolVarArray sinputs;
        for (auto i : inputs) {
            sinputs.emplace_back(i);
        }
        static_assert(!std::is_reference<decltype(callback)>::value);
506 507 508
        auto soutputs = opr::InputCallback::make(*graph, std::move(callback),
                                                 comp_node, dtype, shape,
                                                 sinputs, use_static_shape);
509 510 511 512 513 514 515 516
        std::vector<VarNode*> outputs;
        outputs.reserve(soutputs.size());
        for (auto i : soutputs) {
            outputs.push_back(i.node());
        }
        return outputs;
    };

M
Megvii Engine Team 已提交
517 518 519 520
    m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) {
            return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
        });

521
    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype, std::optional<std::string> name) {
M
Megvii Engine Team 已提交
522
            if (!cn.valid()) {
523
                cn = CompNode::load(get_default_device());
M
Megvii Engine Team 已提交
524
            }
525 526 527 528
            OperatorNodeConfig config(cn);
            if (name) {
                config.name(*name);
            }
M
Megvii Engine Team 已提交
529
            auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
530 531
            return opr::ImmutableTensor::make(*graph, hv, config).node();
        }, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none());
M
Megvii Engine Team 已提交
532

533
    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
534 535 536 537 538 539 540 541 542 543
            if (!cn.valid()) {
                throw py::type_error("device must be valid");
            }
            if (!dtype.valid()) {
                throw py::type_error("dtype must be valid");
            }
            OperatorNodeConfig config;
            if (name) {
                config.name(*name);
            }
544 545
            return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
        }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());
546

547 548 549 550
    m.def("_replace_vars", &_replace_vars,py::arg(),py::arg(),py::arg());
    m.def("_replace_oprs", &_replace_oprs,py::arg(),py::arg(),py::arg());
    m.def("_set_priority_to_id",&_set_priority_to_id,py::arg());

551 552 553
    m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                             const CompNode& comp_node,
                                             const DType& dtype,
M
Megvii Engine Team 已提交
554
                                             const TensorShape& shape,
555
                                             const std::vector<cg::VarNode*>& inputs,
556 557 558 559 560
                                             cg::ComputingGraph* graph,
                                             bool use_static_shape) {
            return input_callback(
                [f=std::move(callback)](){py::gil_scoped_acquire _; return f();},
                comp_node, dtype, shape, inputs, graph, use_static_shape);
561
        },
562 563
        py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(),
        py::arg("graph") = py::none(), py::arg("use_static_shape") = false);
564 565 566 567

    m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p,
                                             const CompNode& comp_node,
                                             const DType& dtype,
M
Megvii Engine Team 已提交
568
                                             const TensorShape& shape,
569
                                             const std::vector<cg::VarNode*>& inputs,
570 571
                                             cg::ComputingGraph* graph,
                                             bool use_static_shape) {
572 573 574
            auto f = [p]() -> DeviceTensorND {
                return p->get();
            };
575
            return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph, use_static_shape);
576
        },
577 578
        py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), 
        py::arg("graph") = py::none(), py::arg("use_static_shape") = false);
579

580 581 582 583 584 585 586 587
    auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs,
            std::shared_ptr<RendezvousBase> r = {}, bool borrow = false, bool prefer_host_value = false) {
        if (r) {
            mgb_assert(inputs.size());
            auto cg = inputs[0]->owner_graph();
            cg->options().user_data.get_user_data_or_create<WeakRendezvousArray>()
                    ->emplace_back(r);
        }
588 589 590 591 592
        SymbolVarArray sinputs;
        for (auto i : inputs) {
            sinputs.emplace_back(i);
        }
        static_assert(!std::is_reference<decltype(callback)>::value);
593
        opr::OutputCallback::Param param{std::move(callback), borrow, prefer_host_value};
594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611
        auto output = opr::OutputCallback::make(std::move(param), sinputs);
        return output.node();
    };

    m.def("output_callback", [output_callback](std::function<void(DeviceTensorND)> callback, std::vector<cg::VarNode*> inputs) {
        auto f = [f=std::move(callback)](DeviceTensorND dv) {
            auto task = [f=std::move(f), dv=std::move(dv)]() {
                f(dv);
            };
            py_task_q.add_task(std::move(task));
        };
        return output_callback(std::move(f), std::move(inputs));
    });

    m.def("output_callback", [output_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p, std::vector<cg::VarNode*> inputs) {
        auto f = [p](DeviceTensorND dv) {
            p->set(std::move(dv));
        };
612
        return output_callback(std::move(f), std::move(inputs), p);
613 614
    });

M
Megvii Engine Team 已提交
615 616 617 618 619 620 621 622
    m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) {
        auto f = [p](DeviceTensorND dv) {
            HostNDWithEvent hv_with_event;
            hv_with_event.first.copy_from(dv);
            hv_with_event.second = dv.comp_node().create_event();
            hv_with_event.second->record();
            p->set(std::move(hv_with_event));
        };
623
        return output_callback(std::move(f), std::move(inputs), p, true, true);
M
Megvii Engine Team 已提交
624 625
    });

626 627 628 629
    m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
        auto f = [p](DeviceTensorND dv) {
            p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
        };
630
        return output_callback(std::move(f), std::move(inputs), p, true);
631
    });
632 633 634 635 636 637 638 639 640 641 642 643

    m.def("virtual_dep", [](std::vector<cg::VarNode*> inputs, std::string device) {
        auto&& graph = inputs[0]->owner_graph();
        VarNodeArray inps(inputs.begin(), inputs.end());
        cg::OperatorNodeConfig config;
        if (device.length() > 0) {
            config.comp_node(CompNode::load(device));
        }
        cg::OperatorNodeBase* opr = graph->insert_opr(
                std::make_unique<mgb::opr::VirtualDep>(inps, config));
        return opr;
    });
644
}