Commit e6706be2 authored by Megvii Engine Team

refactor(imperative): remove infer_output_mem_desc

GitOrigin-RevId: bff62b33a055ef78a0ccdbe92b606edc77872892
Parent a5af35c1
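In short, this commit removes the imperative runtime's separate memory-plan path: the per-op infer_output_mem_desc and execute traits, the MemoryDesc/StorageIdentifier structures, and ChannelImpl::init_output_and_workspace are deleted, and every operator now runs through OpDef::apply_on_physical_tensor. On the Python side the megengine.core.option contextmanager is dropped as well, so callers toggle options directly. A minimal sketch of the replacement pattern, assuming megengine.core.get_option is exported alongside set_option as the first hunk suggests (the try/finally restore is illustrative, not taken verbatim from the tests):

import megengine

# Hedged sketch: the removed `option(key, value)` contextmanager saved the old
# value, set the new one, and restored it on exit; without it, callers do the
# save/restore explicitly.
old = megengine.core.get_option("enable_host_compute")
megengine.core.set_option("enable_host_compute", 0)
try:
    x = megengine.tensor(0)
    y = x + 1
    y.numpy()
finally:
    megengine.core.set_option("enable_host_compute", old)
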
@@ -12,13 +12,3 @@ from contextlib import contextmanager
from ._imperative_rt.core2 import get_option, set_option
from .tensor.megbrain_graph import Graph
@contextmanager
def option(key, value):
value = int(value)
old = get_option(key)
set_option(key, value)
yield
assert get_option(key) == value
set_option(key, old)
@@ -76,10 +76,11 @@ def test_drop_basic():
def test_finalize():
prog = """
import megengine
with megengine.core.option("enable_host_compute", 0):
megengine.core.set_option("enable_host_compute", 0)
x = megengine.tensor(0)
y = x + 1
y.numpy()
megengine.core.set_option("enable_host_compute", 1)
"""
subprocess.check_call([sys.executable, "-c", prog])
......
@@ -15,7 +15,6 @@ import pytest
from megengine import Parameter
from megengine import distributed as dist
from megengine import tensor
from megengine.core import option
from megengine.jit import trace
from megengine.module import Module
from megengine.utils.profiler import Profiler, scope
......
@@ -155,7 +155,6 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
info->h_value = value;
info->desc.value = value.proxy_to_default_cpu();
}
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
m_worker.add_task(
{Profiler::next_id(), Put{info, value, no_cache},
get_channel_state().stack_manager.dump()});
@@ -180,7 +179,6 @@ TensorInfo* ChannelImpl::put_impl(
auto info = alloc();
MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
init(info, {data.layout(), data.comp_node()});
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
info->ptr = Tensor::make(data, hvalue);
MGB_RECORD_EVENT(
TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
@@ -536,9 +534,6 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
info->status = TensorInfo::Allocated;
info->desc = std::move(desc);
info->mem_desc.layout = info->desc.layout;
info->mem_desc.cn = info->desc.comp_node;
info->mem_desc.offset = 0;
}
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
@@ -667,18 +662,14 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
bool profiling_device =
Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
uint64_t apply_id = cmd.id;
struct TensorWithDesc {
TensorPtr tensor;
MemoryDesc desc;
};
SmallVector<TensorWithDesc> inputs;
SmallVector<TensorPtr> inputs;
inputs.reserve(cmd.inputs.size());
// refcnt == 1, owners: [TensorInfo::ptr]
for (auto i : cmd.inputs) {
mgb_assert(i->ptr, "Invalid input tensor ptr!");
// refcnt ++, owners: [i->ptr, tensor_inputs]
// tensor_inputs.push_back(i->ptr);
inputs.push_back({i->ptr, i->mem_desc});
inputs.push_back(i->ptr);
}
if (state.options.enable_dtr_auto_drop &&
state.options.dtr_eviction_threshold > 0) {
@@ -686,56 +677,28 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
}
auto apply_on_physical_tensor =
[&](auto&& self, const OpDef& def,
SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
SmallVector<TensorPtr> inputs) -> SmallVector<TensorPtr> {
auto apply_functor = [&](std::shared_ptr<OpDef> op,
SmallVector<TensorWithDesc> inputs,
SmallVector<TensorPtr> inputs,
size_t nr_outputs) -> SmallVector<TensorWithDesc> {
size_t nr_outputs) -> SmallVector<TensorPtr> {
auto opname = op->trait()->make_name(*op);
imperative_log_profile_begin(opname.c_str());
auto outputs = self(self, *op, inputs);
imperative_log_profile_end(opname.c_str());
return outputs;
};
auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
return {value, MemoryDesc{
value->layout(), 0, value->comp_node(),
StorageIdentifier::make()}};
};
auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
if (def.trait()->make_forward_graph) {
// apply recursively
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& input : inputs) {
input_descs.push_back(
{{{}, input.tensor->dtype()}, input.tensor->comp_node()});
input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
}
auto forward_graph = OpDef::make_forward_graph(def, input_descs);
auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
return outputs;
}
SmallVector<TensorPtr> input_tensors;
return OpDef::apply_on_physical_tensor(def, inputs);
SmallVector<MemoryDesc> input_descs;
for (auto&& input : inputs) {
input_tensors.push_back(input.tensor);
input_descs.push_back(input.desc);
}
auto [output_descs, output_tensors, workspaces] =
init_output_and_workspace(def, input_tensors, input_descs);
if (!output_descs.empty()) {
OpDef::execute(def, input_tensors, output_tensors, workspaces);
} else {
output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors);
for (auto&& output_tensor : output_tensors) {
output_descs.push_back(MemoryDesc{
output_tensor->layout(), 0, output_tensor->comp_node(),
StorageIdentifier::make()});
}
}
SmallVector<TensorWithDesc> outputs;
for (auto&& [output_tensor, output_desc] :
ranges::zip_view(output_tensors, output_descs)) {
outputs.push_back({output_tensor, output_desc});
}
return outputs;
};
MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
// Begin profiling operator
@@ -787,8 +750,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
} else {
MGB_RECORD_EVENT(OpOutputEvent, output->id);
produce_tensor(output, outputs[i].tensor);
produce_tensor(output, outputs[i]);
output->mem_desc = outputs[i].desc;
MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
sample_on_device(output->desc.comp_node, false);
}
@@ -800,7 +762,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
estimate_compute_time += i->memory;
}
for (auto i : outputs) {
estimate_compute_time += i.tensor->blob()->size();
estimate_compute_time += i->blob()->size();
}
m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
for (auto i : cmd.outputs) {
@@ -1012,52 +974,6 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
set_log_level(pre_level);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
ChannelImpl::init_output_and_workspace(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<MemoryDesc> inputs_mem_desc) {
auto [outputs_desc, workspaces_desc] =
OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
if (!outputs_desc.size()) {
// failed to infer memplan
return {{}, {}, {}};
}
// refine storage id to make it unique
for (auto&& desc : outputs_desc) {
if (desc.id->is_sys_alloc()) {
// TODO: there may be some outputs sharing the same storage id
desc.id->id = ++m_storage_id;
}
}
auto& state = get_worker_state();
auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
SmallVector<TensorPtr> tensors;
for (size_t i = 0; i < desc.size(); i++) {
if (desc[i].id->is_sys_alloc()) {
tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) {
alloc_tensor_with_evict(tensors.back()->blob().get());
}
} else if (desc[i].id->is_from_other()) {
for (size_t j = 0; j < inputs_mem_desc.size(); j++) {
if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
tensors.push_back(
inputs[j]->sub(desc[i].offset, desc[i].layout));
break;
}
}
} else if (desc[i].id->is_device_ptr()) {
tensors.push_back(desc[i].id->ptr);
} else {
mgb_assert(0, "not implemented");
}
}
return tensors;
};
return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
}
void ChannelImpl::process_one_task(Command& icmd) {
using namespace ranges;
using namespace ranges::views;
......
@@ -105,11 +105,6 @@ private:
void flush_apply_stack();
void do_apply_op(const ApplyOp& cmd, std::string reason);
std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
init_output_and_workspace(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<MemoryDesc> inputs_mem_desc);
void dispatch_default_cpu(
std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
@@ -296,6 +291,8 @@ private:
op_blacklist.end();
}
// operators that cannot be re-computed, including:
// distributed operators, inplace operators, and random generator operators
std::vector<std::string> op_blacklist = {
"CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat",
"GaussianRNG", "UniformRNG", "GammaRNG", "PermutationRNG",
......
@@ -59,7 +59,6 @@ struct TensorInfo {
// Lock interpreter when visiting `ptr`.
TensorPtr ptr;
LogicalTensorDesc desc;
MemoryDesc mem_desc;
double compute_time;
size_t memory;
......
@@ -41,20 +41,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> OpDef::
infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}
void OpDef::execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace));
}
void OpDef::apply_on_device_tensornd(
const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {
......
@@ -43,13 +43,6 @@ void OpMethFallbackByProxyGraph::impl(
ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor) {
func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
}
void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) {
func.Base::operator=(proxy_graph_detail::execute);
}
void OpMethFallbackByProxyGraph::impl(
InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc) {
func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
}
void OpMethFallbackByProxyGraph::impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible) {
func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
@@ -62,10 +55,6 @@ void OpMethFallbackFromSubgraph::impl(
ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor) {
func.Base::operator=(subgraph_detail::apply_on_physical_tensor);
}
void OpMethFallbackFromSubgraph::impl(
InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc) {
func.Base::operator=(subgraph_detail::infer_output_mem_desc);
}
void OpMethFallbackFromSubgraph::impl(
ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode) {
func.Base::operator=(subgraph_detail::apply_on_var_node);
......
@@ -64,12 +64,6 @@ OpMethType(DecideDispatchMode,
OpMethType(ApplyOnPhysicalTensor,
decltype(OpDef::apply_on_physical_tensor));
OpMethType(InferOutputMemDesc,
decltype(OpDef::infer_output_mem_desc));
OpMethType(Execute,
decltype(OpDef::execute));
OpMethType(ApplyOnDeviceTensorND,
decltype(OpDef::apply_on_device_tensornd));
@@ -123,8 +117,6 @@ struct OpMethFallback : OpMethImplBase {
struct OpMethFallbackByProxyGraph : OpMethImplBase {
using OpMethImplBase::impl;
static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor);
static void impl(Execute& func, op_meth_tag::Execute);
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
static void impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible);
static void impl(GradMaker& func, op_meth_tag::GradMaker);
@@ -133,7 +125,6 @@ struct OpMethFallbackByProxyGraph : OpMethImplBase {
struct OpMethFallbackFromSubgraph : OpMethImplBase {
using OpMethImplBase::impl;
static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor);
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode);
static void impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible);
@@ -185,8 +176,6 @@ struct OpTrait {
OpDefMaker make_from_op_node;
DecideDispatchMode decide_dispatch_mode;
ApplyOnPhysicalTensor apply_on_physical_tensor;
InferOutputMemDesc infer_output_mem_desc;
Execute execute;
ApplyOnDeviceTensorND apply_on_device_tensornd;
ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible;
@@ -207,8 +196,6 @@ struct OpTrait {
cb(make_from_op_node) \
cb(decide_dispatch_mode) \
cb(apply_on_physical_tensor) \
cb(infer_output_mem_desc) \
cb(execute) \
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
......
@@ -81,50 +81,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto& input = inputs_tensors[0];
TensorShape target_shape;
cg::copy_tensor_value_to_shape(
target_shape, inputs_tensors[1]->get_value().proxy_to_default_cpu());
// TODO: memory forward
// if (input->shape().eq_shape(target_shape)) {
// return {{{input->layout(), 0, input->comp_node(),
// StorageIdentifier::make(&inputs_mems[0])}}, {}};
// }
return {{{{target_shape, input->dtype()},
0,
input->comp_node(),
StorageIdentifier::make(0)}},
{}};
}
void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
if (outputs[0]->layout().is_empty()) {
return;
}
if (inputs[0]->shape().eq_shape(outputs[0]->shape())) {
mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout()));
// TODO: memory forward
// mgb_assert(inputs[0]->offset() == outputs[0]->offset());
// mgb_assert(inputs[0]->blob() == outputs[0]->blob());
outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor());
} else {
TensorLayout input_layout = inputs[0]->layout().broadcast(outputs[0]->shape());
outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor().sub(
SubTensorSpec::make_from_layout(input_layout)));
}
}
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace broadcast
@@ -187,41 +147,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& op_def = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op_def.axis] == -1);
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
TensorLayout tlayout = slayout.reshape(tshp);
// memory forward
return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}},
{}};
}
void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
mgb_assert(inputs[0]->offset() == outputs[0]->offset());
mgb_assert(inputs[0]->blob() == outputs[0]->blob());
}
OP_TRAIT_REG(Reshape, Reshape)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace reshape
......
@@ -78,25 +78,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
false};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}
void execute(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
mgb_assert(0);
}
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace
......
@@ -234,12 +234,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return op.infer_output_attrs(inputs);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}
size_t hash(const OpDef& def) {
auto&& op = static_cast<const CustomOpDef&>(def);
const custom::Param& param = op.param();
@@ -279,7 +273,6 @@ OP_TRAIT_REG(CustomOpDef, CustomOpDef)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.hash(hash)
.is_same_st(is_same_st)
.props(props)
......
@@ -110,35 +110,6 @@ void apply_on_device_tensornd(
opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
}
void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
mgb_assert(outputs.size() == 1);
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
inp_tensornds[i] = inputs[i]->dev_tensor();
}
SmallVector<DeviceTensorND> out_tensornds = {outputs[0]->dev_tensor()};
apply_on_device_tensornd(def, inp_tensornds, &out_tensornds);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& op_def = def.cast_final_safe<Elemwise>();
TensorShapeArray inp_shapes(inputs_tensors.size());
for (size_t i = 0; i < inputs_tensors.size(); ++i) {
inp_shapes[i] = inputs_tensors[i]->layout();
}
TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
SmallVector<MemoryDesc> outputs = {
{{shape, inputs_tensors[0]->dtype()},
0,
inputs_tensors[0]->comp_node(),
StorageIdentifier::make(1)}};
return {outputs, {}};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
@@ -251,7 +222,7 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
mgb_assert(
inputs[0]->blob().use_count() == 2 && inputs[0]->blob()->storage().unique(),
inputs[0]->blob().use_count() == 1 && inputs[0]->blob()->storage().unique(),
"This inplace modification may change the elements of other tensors. "
"Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs "
"correctly.");
@@ -265,23 +236,6 @@ SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())};
}
void execute_inplace(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
apply_inplace_add_on_physical_tensor(def, inputs);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_inplace_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto dest = inputs_tensors[0];
SmallVector<MemoryDesc> outputs = {
{dest->layout(), 0, dest->comp_node(),
StorageIdentifier::make(&inputs_mems[0])}};
return {outputs, {}};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(inputs.size() == 4, "invalid input number for inplace_add");
@@ -319,16 +273,12 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
.apply_on_var_node(apply_inplace_add_on_var_node)
.apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
.infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
.infer_output_mem_desc(infer_inplace_output_mem_desc)
.execute(execute_inplace)
.fallback();
} // anonymous namespace
......
@@ -75,16 +75,11 @@ SmallVector<LogicalTensorDesc> infer_output_attrs(
dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
return dests;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}
OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.fallback();
} // namespace check_non_finite
......
@@ -36,6 +36,7 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
return Reduce::make(node->param());
}
// TODO: use this for apply_on_physical_tensor
bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) {
@@ -49,31 +50,9 @@ bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
return false;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
if (memory_forward_success(def, inputs_tensors)) {
auto& src_desc = inputs_mems[0];
return {{{src_desc.layout, 0, src_desc.cn, StorageIdentifier::make(&src_desc)}},
{}};
}
return proxy_graph_detail::infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}
void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
if (memory_forward_success(def, inputs)) {
return;
}
return proxy_graph_detail::execute(def, inputs, outputs, workspace);
}
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace reduce
} // namespace
......
@@ -517,20 +517,6 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
return dests;
}
template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& dests = infer_output_attrs<Op>(def, inputs_tensors);
SmallVector<MemoryDesc> outputs;
for (size_t i = 0; i < dests.size(); ++i) {
outputs.push_back(
{dests[i].layout, 0, dests[i].comp_node,
StorageIdentifier::make(i + 1)});
}
return {outputs, {}};
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
@@ -543,13 +529,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return outputs;
}
template <typename Op>
void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
exec<Op>(def, inputs, outputs, {});
}
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
size_t nr_inp = inputs.size();
@@ -641,8 +620,6 @@ CompNode get_rng_handle_compnode(Handle handle) {
.apply_on_var_node(apply_on_var_node<NAME, Output>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.infer_output_mem_desc(infer_output_mem_desc<NAME>) \
.execute(execute<NAME>) \
.fallback(); \
}
......
@@ -141,39 +141,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
SmallVector<MemoryDesc> ret;
auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
if (blob) {
ret.push_back(
{tensor.layout(), 0, inputs[0]->comp_node(),
StorageIdentifier::make(Tensor::make(
std::forward<decltype(blob)>(blob), tensor.layout(),
tensor))});
} else {
ret.push_back(
{tensor.layout(), 0, inputs[0]->comp_node(),
StorageIdentifier::make(1)});
}
return {ret, {}};
}
void execute(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
SmallVector<MemoryDesc> ret;
auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
if (!blob || blob->storage() != outputs[0]->blob()->storage()) {
outputs[0]->dev_tensor().copy_from_fixlayout(tensor);
AsyncReleaser::inst()->add(tensor);
}
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::GetVarShape>();
return GetVarShape::make(node->param());
@@ -186,8 +153,6 @@ OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace get_var_shape
@@ -215,38 +180,6 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
return opr;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
param_pack_split_infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& param = def.cast_final_safe<ParamPackSplit>();
mgb_assert(
inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size());
auto&& inp = inputs[0];
auto&& shp = inp->layout();
mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
mgb_assert(param.shapes.size() * 2 == param.offsets.size());
SmallVector<MemoryDesc> ret;
auto&& shapes = get_shapes(param.shapes);
size_t dtype_size = inputs[0]->layout().dtype.size();
for (size_t i = 0; i < shapes.size(); ++i) {
// memory forward
ret.push_back(
{{shapes[i], inputs[0]->dtype()},
param.offsets[i * 2] * dtype_size,
inp->comp_node(),
StorageIdentifier::make(&inputs_mems[0])});
}
return {ret, {}};
}
void param_pack_split_execute(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
// do nothing
}
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& param = def.cast_final_safe<ParamPackSplit>();
@@ -268,8 +201,6 @@ SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
.apply_on_var_node(param_pack_split_apply_on_var_node)
.infer_output_mem_desc(param_pack_split_infer_output_mem_desc)
.execute(param_pack_split_execute)
.apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
.fallback();
@@ -286,75 +217,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
return opr;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
param_pack_concat_infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
auto comp_node = inputs.front()->comp_node();
auto dtype = inputs.front()->dtype();
size_t nr_inputs = inputs.size() - 1;
size_t nr_elems = 0;
for (size_t i = 0; i < nr_inputs; ++i) {
auto& input = inputs[i];
mgb_assert(
comp_node == input->comp_node(),
"inputs for param_pack_concat must in same comp_node");
mgb_assert(
dtype == input->dtype(),
"inputs for param_pack_concat must have same dtype");
nr_elems += input->layout().total_nr_elems();
}
auto dest_layout = TensorLayout({nr_elems}, dtype);
auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
size_t ws_size;
{
TensorShapeArray src_shapes;
for (size_t i = 0; i < nr_inputs; ++i) {
src_shapes.push_back(inputs[i]->shape());
}
ws_size = caller.op->get_workspace_in_bytes(
src_shapes, inputs.back()->shape(), TensorShape{});
}
SmallVector<MemoryDesc> outputs = {
{dest_layout, 0, comp_node, StorageIdentifier::make(1)}};
MemoryDesc workspace = {
{{ws_size}, dtype::Byte()}, 0, comp_node, StorageIdentifier::make(2)};
return {outputs, {workspace}};
}
void param_pack_concat_execute(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
auto comp_node = inputs.front()->comp_node();
size_t nr_inputs = inputs.size() - 1;
auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
size_t srcs_size = sizeof(void*) * nr_inputs;
void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
std::shared_ptr<dt_byte> srcs_ptr = {
(dt_byte*)srcs_raw_ptr,
[comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }};
TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
for (size_t i = 0; i < nr_inputs; ++i) {
srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr();
}
HostTensorStorage srcs_storage;
srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
megdnn::Workspace dnn_wk(
workspace[0]->blob()->storage().get(), workspace[0]->blob()->size());
caller.op->exec(
{srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(),
outputs[0]->dev_tensor().as_megdnn(), dnn_wk);
AsyncReleaser::inst()->add(
HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
}
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
def.cast_final_safe<ParamPackConcat>();
@@ -407,8 +269,6 @@ SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.apply_on_var_node(param_pack_concat_apply_on_var_node)
.infer_output_mem_desc(param_pack_concat_infer_output_mem_desc)
.execute(param_pack_concat_execute)
.apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
.fallback();
} // namespace param_pack
......
@@ -445,12 +445,6 @@ auto make_name(const OpDef& def) {
return ssprintf("CompiledOp[%s]", op.op->make_name().c_str());
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {};
}
EncodedSubgraph make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
@@ -498,7 +492,6 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.make_backward_graph(make_backward_graph)
.make_name(make_name)
.infer_output_mem_desc(infer_output_mem_desc)
.props(props)
.hash(hash)
.is_same_st(is_same_st)
......
@@ -634,36 +634,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::
mgb_assert(0);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
infer_output_mem_desc(
const OpDef& def, const SmallVector<Tensor*>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto opr = get_proxy_opr(def, inputs_tensors);
CUR_OPR_GUARD(opr);
::mgb::opr::intl::WorkspaceLimitHook::set_impl(
m_graph.get(), ProxyGraph::get_workspace_limit);
do_shape_infer(true);
SmallVector<MemoryDesc> outputs;
SmallVector<MemoryDesc> workspaces;
size_t cur_id = 0;
for (auto&& i : opr->output()) {
if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
workspaces.push_back(
{{i->shape(), i->dtype(), i->format()},
0,
i->comp_node(),
StorageIdentifier::make(++cur_id)});
} else {
outputs.push_back(
{{i->shape(), i->dtype()},
0,
i->comp_node(),
StorageIdentifier::make(++cur_id)});
}
}
return {outputs, workspaces};
}
struct ProxyGraph::GradGraph {
cg::VarNodeArray inputs;
cg::VarNodeArray outputs;
@@ -812,7 +782,6 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
return result;
}
VarNodeArray ProxyGraph::make_input_place_holders(
const SmallVector<LogicalTensorDesc>& inputs) {
VarNodeArray vinputs(inputs.size());
......
@@ -47,10 +47,6 @@ public:
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<Tensor*>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);
/********************** Logical Tensor API **********************/
size_t get_opr_output_size(
......
@@ -83,25 +83,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return outputs;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& graph = ProxyGraph::get_default_graph();
return graph->infer_output_mem_desc(
def, to_raw_ptr_array(inputs_tensors), inputs_mems);
}
void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
exec(def, inputs, outputs, workspace);
auto async_error = ProxyGraph::get_async_error();
if (async_error) {
throw *async_error;
}
return;
}
// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const
// OpDef& def,
// const SmallVector<LogicalTensorDesc>& inputs) {
......
@@ -162,12 +162,6 @@ EncodedSubgraph make_backward_graph(
inputs, input_requires_grad, output_has_grad, forward_graph);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}
} // namespace subgraph_detail
} // namespace imperative
} // namespace mgb
@@ -53,10 +53,6 @@ public:
static SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);
static void execute(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs, SmallVector<TensorPtr> workspace);
/*!
* \brief Call the corresponding dnn op to calculate results. Output
* tensors' device memory should be allocated outside.
@@ -71,11 +67,6 @@ public:
static std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
static std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);
static EncodedSubgraph make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
......
@@ -288,36 +288,6 @@ struct LogicalTensorDesc {
CompNode comp_node;
DeviceTensorND value; // cpu:default
};
struct StorageIdentifier;
struct MemoryDesc {
TensorLayout layout;
size_t offset;
CompNode cn;
std::shared_ptr<StorageIdentifier> id;
};
struct StorageIdentifier {
enum { INVALID, SYS_ALLOC, FROM_OTHER, DEVICE_PTR } tag;
union {
size_t id;
MemoryDesc* desc;
};
TensorPtr ptr;
StorageIdentifier() = default;
StorageIdentifier(size_t id) : tag(SYS_ALLOC), id(id) {}
StorageIdentifier(const MemoryDesc* desc) : tag(FROM_OTHER), desc(desc->id->desc) {}
StorageIdentifier(TensorPtr dev_ptr) : tag(DEVICE_PTR), ptr(dev_ptr) {}
template <typename... Args>
static std::shared_ptr<StorageIdentifier> make(Args&&... args) {
return std::make_shared<StorageIdentifier>(std::forward<Args>(args)...);
}
bool is_sys_alloc() { return tag == SYS_ALLOC; }
bool is_from_other() { return tag == FROM_OTHER; }
bool is_device_ptr() { return tag == DEVICE_PTR; }
bool is_invalid() { return tag == INVALID; }
};
} // namespace imperative
} // namespace mgb
......
@@ -20,17 +20,9 @@ namespace proxy_graph_detail {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);
void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);
void exec(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs);
......
@@ -35,10 +35,6 @@ EncodedSubgraph make_backward_graph(
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);
} // namespace subgraph_detail
} // namespace imperative
} // namespace mgb
\ No newline at end of file