From 798f7b3ef66e18f30db18a45d1628164f2407d05 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 18 Nov 2020 11:31:51 +0800 Subject: [PATCH] feat(imperative): add virtual_deps op GitOrigin-RevId: 5845520989ec573943b8910503c8e4c626aa0363 --- imperative/python/src/ops.cpp | 4 ++ imperative/src/impl/ops/utility.cpp | 39 +++++++++++++++++++ .../include/megbrain/imperative/ops/utility.h | 35 +++++++++++++++++ 3 files changed, 78 insertions(+) create mode 100644 imperative/src/impl/ops/utility.cpp create mode 100644 imperative/src/include/megbrain/imperative/ops/utility.h diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index f48c330fd..27846a9db 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -22,6 +22,7 @@ #include "megbrain/imperative/ops/elemwise.h" #include "megbrain/imperative/ops/batch_norm.h" #include "megbrain/imperative/ops/broadcast.h" +#include "megbrain/imperative/ops/utility.h" namespace py = pybind11; @@ -113,6 +114,9 @@ void init_ops(py::module m) { .def(py::init<>()) .def_readwrite("offsets", &ParamPackConcat::offsets); + py::class_, OpDef>(m, "VirtualDep") + .def(py::init<>()); + py::class_, OpDef>(m, "CondTake") .def(py::init<>()); diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp new file mode 100644 index 000000000..b9a701dad --- /dev/null +++ b/imperative/src/impl/ops/utility.cpp @@ -0,0 +1,39 @@ +/** + * \file imperative/src/impl/ops/utility.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/ops/utility.h" +#include "megbrain/imperative/ops/opr_attr.h" +#include "megbrain/opr/utility.h" +#include "../op_trait.h" + +namespace mgb::imperative { +namespace { + +cg::OperatorNodeBase* virtual_dep_apply_on_var_node( + const OpDef& def, const VarNodeArray& inputs) { + auto&& graph = inputs[0]->owner_graph(); + + VarNodeArray inps(inputs.begin(), inputs.end()); + cg::OperatorNodeConfig config; + cg::OperatorNodeBase* opr = + graph->insert_opr(std::make_unique( + inps, config)); + return opr; +} + +OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep) + .apply_on_var_node(virtual_dep_apply_on_var_node) + .fallback(); +} // namespace + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep); + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h new file mode 100644 index 000000000..817935218 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/ops/utility.h @@ -0,0 +1,35 @@ +/** + * \file imperative/src/include/megbrain/imperative/ops/utility.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/imperative/op_def.h" + +#include "megbrain/utils/hash.h" + +namespace mgb::imperative { + +class VirtualDep : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + VirtualDep() = default; + + size_t hash() const override { + return reinterpret_cast(dyn_typeinfo()); + } + + bool is_same_st(const Hashable& rhs) const override { + return true; + } +}; + +} // namespace mgb::imperative -- GitLab