diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index f48c330fd0c34a1b2a72dd1893ce911ff45b597f..27846a9db363b9e9d84449528aaae504c1ae93c7 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -22,6 +22,7 @@ #include "megbrain/imperative/ops/elemwise.h" #include "megbrain/imperative/ops/batch_norm.h" #include "megbrain/imperative/ops/broadcast.h" +#include "megbrain/imperative/ops/utility.h" namespace py = pybind11; @@ -113,6 +114,9 @@ void init_ops(py::module m) { .def(py::init<>()) .def_readwrite("offsets", &ParamPackConcat::offsets); + py::class_, OpDef>(m, "VirtualDep") + .def(py::init<>()); + py::class_, OpDef>(m, "CondTake") .def(py::init<>()); diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b9a701dadc03f1183b7dbd8e291c69e9ab623fea --- /dev/null +++ b/imperative/src/impl/ops/utility.cpp @@ -0,0 +1,39 @@ +/** + * \file imperative/src/impl/ops/utility.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/ops/utility.h" +#include "megbrain/imperative/ops/opr_attr.h" +#include "megbrain/opr/utility.h" +#include "../op_trait.h" + +namespace mgb::imperative { +namespace { + +cg::OperatorNodeBase* virtual_dep_apply_on_var_node( + const OpDef& def, const VarNodeArray& inputs) { + auto&& graph = inputs[0]->owner_graph(); + + VarNodeArray inps(inputs.begin(), inputs.end()); + cg::OperatorNodeConfig config; + cg::OperatorNodeBase* opr = + graph->insert_opr(std::make_unique( + inps, config)); + return opr; +} + +OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep) + .apply_on_var_node(virtual_dep_apply_on_var_node) + .fallback(); +} // namespace + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep); + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h new file mode 100644 index 0000000000000000000000000000000000000000..817935218989e2ff646c141a0c8b4cb078adbd1f --- /dev/null +++ b/imperative/src/include/megbrain/imperative/ops/utility.h @@ -0,0 +1,35 @@ +/** + * \file imperative/src/include/megbrain/imperative/ops/utility.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/imperative/op_def.h" + +#include "megbrain/utils/hash.h" + +namespace mgb::imperative { + +class VirtualDep : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + VirtualDep() = default; + + size_t hash() const override { + return reinterpret_cast(dyn_typeinfo()); + } + + bool is_same_st(const Hashable& rhs) const override { + return true; + } +}; + +} // namespace mgb::imperative