Commit b4581788 authored by Megvii Engine Team

feat(opr): add mutable tensor opr

GitOrigin-RevId: 7f8a3d7b661c18fb407047a78b965630b52e61d9
Parent 47fe7663
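
This change adds opr::MutableTensor, a single-comp-node operator whose output is backed by a caller-owned DeviceTensorND that is re-read on every graph execution; an optional HostTensorND mirror enables static value inference. A minimal usage sketch follows (illustrative only; the comp node, shape and dtype below are assumed and not part of this diff):

    // sketch only: setup values are assumed for illustration
    auto graph = cg::ComputingGraph::make();
    auto cn = CompNode::load("xpu0");
    auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{4}, dtype::Float32());
    auto host = std::make_shared<HostTensorND>(cn, TensorShape{4}, dtype::Float32());

    // the opr keeps the shared_ptrs and re-reads *dev on every execution, so the
    // caller may refill the storage between runs without rebuilding the graph
    auto var = opr::MutableTensor::make(*graph, dev, host);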
@@ -271,6 +271,69 @@ void NopCallback::do_execute(ExecEnv& env) {
    env.dispatch_on_comp_node(cn, runner);
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MutableTensor);

MutableTensor::MutableTensor(
        cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
        std::shared_ptr<HostTensorND> host_tensor, const OperatorNodeConfig& config)
        : Super(&graph, config, {}, {}) {
    m_dev_tensor = dev_tensor;
    m_host_tensor = host_tensor;
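    // no graph-managed storage for the output: it is rebound to the external
    // device tensor in scn_do_execute(), and empty tensors are allowed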
    add_output(None)
            ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
            .add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
            .dtype(m_dev_tensor->dtype());
    add_equivalence_component<ScalarHash<const void*>>(this);
}

SymbolVar MutableTensor::make(
        cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
        std::shared_ptr<HostTensorND> host_tensor, const OperatorNodeConfig& config) {
    return graph
            .insert_opr(std::make_unique<MutableTensor>(
                    graph, dev_tensor, host_tensor, config))
            ->output(0);
}
void MutableTensor::init_output_comp_node() {
    if (config().has_comp_node_set()) {
        mgb_assert(
                config().get_single_comp_node() == m_dev_tensor->comp_node(),
                "comp_node mismatch");
    }
    comp_node(m_dev_tensor->comp_node());
}
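
// the output aliases externally owned storage that may change between runs,
// so its memory plan must be regenerated on every execution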
cg::OperatorNodeBase::NodeProp* MutableTensor::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_flag(NodeProp::Flag::IMPURE_OUTPUT_MEM_PLAN);
    return ret;
}
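
// forward the current storage of the shared device tensor to the output var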
void MutableTensor::scn_do_execute() {
    output(0)->reset_dev_tensor_from_tensor(*m_dev_tensor);
}
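
// shape is always inferable from the device tensor; value inference is
// registered only when a host mirror is provided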
void MutableTensor::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto& mgr = owner_graph()->static_infer_manager();

    auto infer_shape = [this](TensorShape& dest, const InpVal&) {
        dest = m_dev_tensor->shape();
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::MUTABLE, {}, infer_shape});

    if (m_host_tensor) {
        auto infer_value = [this](DeviceTensorND& dest, const InpVal&) {
            if (!m_host_tensor->layout().ndim) {
                return false;
            }
            dest = m_host_tensor->proxy_to_default_cpu();
            return true;
        };
        mgr.register_value_infer(output(0), {SourceType::MUTABLE, {}, infer_value});
    }
}
} // namespace opr
} // namespace mgb
@@ -16,6 +16,7 @@
#include "megbrain/opr/internal/identical_fwd.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/param_tag_defs.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/serialization/sereg.h"
@@ -106,6 +107,28 @@ protected:
private:
    callback_t m_callback;
};
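
/*!
 * \brief produce a tensor whose device storage is shared with the caller and
 *      re-read on every graph execution; the optional host tensor mirror
 *      enables static value inference
 */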
MGB_DEFINE_OPR_CLASS(MutableTensor, cg::SingleCNOperatorNodeBase) // {
public:
    MutableTensor(
            cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
            std::shared_ptr<HostTensorND> host_tensor,
            const OperatorNodeConfig& config);

    static SymbolVar make(
            cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
            std::shared_ptr<HostTensorND> host_tensor = {},
            const OperatorNodeConfig& config = {});

protected:
    void init_output_comp_node() override;
    void init_output_static_infer_desc() override;
    cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override;
    void scn_do_execute() override;

private:
    std::shared_ptr<DeviceTensorND> m_dev_tensor;
    std::shared_ptr<HostTensorND> m_host_tensor;
};
} // namespace opr
} // namespace mgb