提交 798f7b3e 编写于 作者: M Megvii Engine Team

feat(imperative): add virtual_deps op

GitOrigin-RevId: 5845520989ec573943b8910503c8e4c626aa0363
上级 1e71e0af
......@@ -22,6 +22,7 @@
#include "megbrain/imperative/ops/elemwise.h"
#include "megbrain/imperative/ops/batch_norm.h"
#include "megbrain/imperative/ops/broadcast.h"
#include "megbrain/imperative/ops/utility.h"
namespace py = pybind11;
......@@ -113,6 +114,9 @@ void init_ops(py::module m) {
.def(py::init<>())
.def_readwrite("offsets", &ParamPackConcat::offsets);
py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep")
.def(py::init<>());
py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake")
.def(py::init<>());
......
/**
* \file imperative/src/impl/ops/utility.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/utility.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace {
cg::OperatorNodeBase* virtual_dep_apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
auto&& graph = inputs[0]->owner_graph();
VarNodeArray inps(inputs.begin(), inputs.end());
cg::OperatorNodeConfig config;
cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>(
inps, config));
return opr;
}
OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep)
.apply_on_var_node(virtual_dep_apply_on_var_node)
.fallback();
} // namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep);
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/ops/utility.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
namespace mgb::imperative {
class VirtualDep : public OpDefImplBase<VirtualDep> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
VirtualDep() = default;
size_t hash() const override {
return reinterpret_cast<size_t>(dyn_typeinfo());
}
bool is_same_st(const Hashable& rhs) const override {
return true;
}
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册