Commit 69d1fd0f authored by Megvii Engine Team, committed by huangxinda

refactor(opdef): split apply_on_physical_tensor into infer_output_mem_desc and execute

GitOrigin-RevId: 4d62b7cbbd5289df9f9603fba62d31eece5856b5
Parent 75eb04c5
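The diff below is easier to follow with the new call sequence in mind: the interpreter first asks an op for memory descriptors of its outputs and workspaces (infer_output_mem_desc), allocates or forwards storage for them, and only then runs the kernel into that storage (execute); when descriptor inference returns nothing, it falls back to the old apply_on_physical_tensor path. What follows is a minimal, self-contained sketch of that two-phase pattern; ToyOp, Desc and Buffer are illustrative stand-ins, not MegEngine types.

// Toy model of the "infer memory descriptors, then execute" split.
// Only the control flow mirrors the diff; all types here are hypothetical.
#include <cstddef>
#include <cstdio>
#include <optional>
#include <vector>

struct Desc { std::size_t nr_elems; };   // stand-in for MemoryDesc
using Buffer = std::vector<float>;       // stand-in for a device tensor

struct ToyOp {
    // Phase 1: describe outputs without touching device memory.
    std::optional<std::vector<Desc>> infer_output_mem_desc(const std::vector<Buffer>& inputs) const {
        if (inputs.empty()) return std::nullopt;          // "failed to infer memplan"
        return std::vector<Desc>{{inputs[0].size()}};
    }
    // Phase 2: run the kernel into storage the caller already allocated.
    void execute(const std::vector<Buffer>& inputs, std::vector<Buffer>& outputs) const {
        for (std::size_t i = 0; i < outputs[0].size(); ++i)
            outputs[0][i] = inputs[0][i] + 1.f;
    }
    // Legacy single-call path: allocate and compute together.
    std::vector<Buffer> apply_on_physical_tensor(const std::vector<Buffer>& inputs) const {
        std::vector<Buffer> outs{Buffer(inputs[0].size())};
        execute(inputs, outs);
        return outs;
    }
};

int main() {
    ToyOp op;
    std::vector<Buffer> inputs{{1.f, 2.f, 3.f}};
    std::vector<Buffer> outputs;
    if (auto descs = op.infer_output_mem_desc(inputs)) {
        for (auto&& d : *descs) outputs.emplace_back(d.nr_elems);  // allocate up front
        op.execute(inputs, outputs);                               // new split path
    } else {
        outputs = op.apply_on_physical_tensor(inputs);             // fallback, as in do_apply_op
    }
    std::printf("%g\n", outputs[0][2]);  // prints 4
}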
@@ -124,6 +124,7 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
auto info = alloc();
init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
info->h_value = value;
m_buffer.enqueue(Put{info, value, no_cache});
if (m_async_level == 0) {
@@ -141,6 +142,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
auto info = alloc();
RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
init(info, {data.layout(), data.comp_node()});
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
info->ptr = Tensor::make(data);
RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
info->status = TensorInfo::Produced;
@@ -487,6 +489,9 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
info->status = TensorInfo::Allocated;
info->desc = std::move(desc);
info->mem_desc.layout = info->desc.layout;
info->mem_desc.cn = info->desc.comp_node;
info->mem_desc.offset = 0;
}
@@ -605,6 +610,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
bool profiling_device = Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
uint64_t apply_id = cmd.id;
SmallVector<TensorPtr> tensor_inputs;
SmallVector<MemoryDesc> input_memory_desc;
if (state.options.enable_dtr_auto_drop) {
m_dtr.pin(cmd.inputs);
}
@@ -618,8 +624,27 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
// refcnt == 1, owners: [TensorInfo::ptr]
for (auto i : cmd.inputs) {
mgb_assert(i->ptr, "Invalid input tensor ptr!");
mgb_assert(i->mem_desc.id, "Invalid input tensor mem desc!");
// refcnt ++, owners: [i->ptr, tensor_inputs]
tensor_inputs.push_back(i->ptr);
input_memory_desc.push_back(i->mem_desc);
}
// SmallVector<MemoryDesc> outputs_mem_desc;
// SmallVector<TensorPtr> tensor_outputs, workspaces;
auto [outputs_mem_desc, tensor_outputs, workspaces] = init_output_and_workspace(*cmd.op, tensor_inputs, input_memory_desc);
if (outputs_mem_desc.size()) {
for (size_t i = 0;i < outputs_mem_desc.size();i ++) {
if (cmd.outputs[i]) {
cmd.outputs[i]->mem_desc = outputs_mem_desc[i];
}
}
} else {
// fail to infer mem plan
for (auto && out : cmd.outputs) {
if (out) {
out->mem_desc.id = StorageIdentifier::make();
}
}
}
RECORD_EVENT(OpExecuteEvent, apply_id);
// Begin profiling operator
@@ -662,8 +687,13 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
}
// Apply op
// Here std::move is REQUIRED for removing duplicated references.
if (outputs_mem_desc.size()) {
OpDef::execute(
*cmd.op, std::move(tensor_inputs), tensor_outputs, std::move(workspaces));
} else {
tensor_outputs = OpDef::apply_on_physical_tensor(
*cmd.op, std::move(tensor_inputs));
}
// After execute
for (auto&& [device, kernel_id]: kernels) {
RECORD_EVENT(KernelExecuteFinishEvent, apply_id, kernel_id, Timer::record_event(device));
@@ -829,6 +859,47 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
return valid_tensors;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> ChannelImpl::init_output_and_workspace(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<MemoryDesc> inputs_mem_desc) {
auto [outputs_desc, workspaces_desc] = OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
if (!outputs_desc.size()) {
// failed to infer memplan
return {{}, {}, {}};
}
// refine storage id to make it unique
for (auto&& desc : outputs_desc) {
if (desc.id->is_sys_alloc()) {
// TODO: there may be some outputs sharing the same storage id
desc.id->id = ++ m_storage_id;
}
}
auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
SmallVector<TensorPtr> tensors;
for (size_t i = 0; i < desc.size(); i ++) {
if (desc[i].id->is_sys_alloc()) {
tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
} else if (desc[i].id->is_from_other()) {
for (size_t j = 0; j < inputs_mem_desc.size();j ++) {
if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
tensors.push_back(inputs[j]->sub(desc[i].offset, desc[i].layout));
break;
}
}
} else if (desc[i].id->is_device_ptr()) {
tensors.push_back(desc[i].id->ptr);
} else {
mgb_assert(0, "not implemented");
}
}
return tensors;
};
return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
}
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
using namespace ranges;
using namespace ranges::views;
...
@@ -102,6 +102,11 @@ private:
void recompute(TensorInfo::ComputePath* path);
void do_apply_op(const ApplyOp& cmd);
std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> init_output_and_workspace(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<MemoryDesc> inputs_mem_desc);
void dispatch_default_cpu(
std::shared_ptr<OpDef> op,
const SmallVector<TensorInfo*>& input_infos,
@@ -139,6 +144,7 @@ private:
uint64_t m_waitee_id = 0;
std::exception_ptr m_worker_exc;
std::function<void(std::string, std::string)> m_profile_dump_callback;
size_t m_storage_id = 0;
bool m_closed = false;
...
@@ -58,6 +58,7 @@ struct TensorInfo {
// Lock interpreter when visiting `ptr`.
TensorPtr ptr;
LogicalTensorDesc desc;
MemoryDesc mem_desc;
double compute_time;
size_t memory;
...
@@ -45,6 +45,21 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> OpDef::infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}
void OpDef::execute(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace));
}
void OpDef::apply_on_device_tensornd(
const OpDef& def,
const SmallVector<DeviceTensorND>& inputs,
...
@@ -79,6 +79,13 @@ OpTraitRegistry& OpTraitRegistry::fallback() {
trait->apply_on_physical_tensor =
proxy_graph_detail::apply_on_physical_tensor;
}
if (!trait->execute) {
trait->execute = proxy_graph_detail::execute;
}
if (!trait->infer_output_mem_desc) {
trait->infer_output_mem_desc =
proxy_graph_detail::infer_output_mem_desc;
}
if (!trait->infer_output_attrs_fallible) {
trait->infer_output_attrs_fallible =
proxy_graph_detail::infer_output_attrs_fallible;
...
@@ -64,6 +64,10 @@ using DecideDispatchMode = detail::OpMeth<
decltype(OpDef::decide_dispatch_mode)>;
using ApplyOnPhysicalTensor = detail::OpMeth<
decltype(OpDef::apply_on_physical_tensor)>;
using InferOutputMemDesc = detail::OpMeth<
decltype(OpDef::infer_output_mem_desc)>;
using Execute = detail::OpMeth<
decltype(OpDef::execute)>;
using ApplyOnDeviceTensorND = detail::OpMeth<
decltype(OpDef::apply_on_device_tensornd)>;
using ApplyOnVarNode = detail::OpMeth<
@@ -82,6 +86,8 @@ struct OpTrait {
OpDefMaker make_from_op_node;
DecideDispatchMode decide_dispatch_mode;
ApplyOnPhysicalTensor apply_on_physical_tensor;
InferOutputMemDesc infer_output_mem_desc;
Execute execute;
ApplyOnDeviceTensorND apply_on_device_tensornd;
ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible;
@@ -100,6 +106,8 @@ struct OpTrait {
cb(make_from_op_node) \
cb(decide_dispatch_mode) \
cb(apply_on_physical_tensor) \
cb(infer_output_mem_desc) \
cb(execute) \
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
...
@@ -79,10 +79,24 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}, false};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}
void execute(const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
mgb_assert(0);
}
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace
...
@@ -118,6 +118,35 @@ void apply_on_device_tensornd(
opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
}
void execute(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
mgb_assert(outputs.size() == 1);
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
for (size_t i = 0;i < inputs.size(); ++i) {
inp_tensornds[i] = inputs[i]->dev_tensor();
}
SmallVector<DeviceTensorND> out_tensornds = {outputs[0]->dev_tensor()};
apply_on_device_tensornd(def, inp_tensornds, &out_tensornds);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& op_def = def.cast_final_safe<Elemwise>();
TensorShapeArray inp_shapes(inputs_tensors.size());
for (size_t i = 0;i < inputs_tensors.size(); ++i) {
inp_shapes[i] = inputs_tensors[i]->layout();
}
TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
SmallVector<MemoryDesc> outputs = {{{shape, inputs_tensors[0]->dtype()}, 0, inputs_tensors[0]->comp_node(), StorageIdentifier::make(1)}};
return {outputs, {}};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
@@ -224,7 +253,7 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs){
mgb_assert(inputs[0]->blob().use_count() == 2 && inputs[0]->blob()->storage().unique(),
"This inplace modification may change the elements of other tensors. " "This inplace modification may change the elements of other tensors. "
"Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs correctly."); "Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs correctly.");
auto dest = inputs[0], delta = inputs[1], auto dest = inputs[0], delta = inputs[1],
...@@ -238,6 +267,23 @@ SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor( ...@@ -238,6 +267,23 @@ SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
return { std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout()) }; return { std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout()) };
} }
void execute_inplace(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
apply_inplace_add_on_physical_tensor(def, inputs);
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_inplace_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto dest = inputs_tensors[0];
SmallVector<MemoryDesc> outputs = {{dest->layout(), 0, dest->comp_node(), StorageIdentifier::make(&inputs_mems[0])}};
return {outputs, {}};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
@@ -271,12 +317,16 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
.apply_on_var_node(apply_inplace_add_on_var_node)
.apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
.infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
.infer_output_mem_desc(infer_inplace_output_mem_desc)
.execute(execute_inplace)
.fallback();
} // anonymous namespace
...
@@ -331,6 +331,7 @@ struct _RNGOprInvoker<DNN_NR_INPUTS> {
} \
};
#define _INST_RNG_MAKER(MGB_NR_INPUTS) \
template<> \
struct _RNGOprMaker<MGB_NR_INPUTS> { \
@@ -366,7 +367,7 @@ _INST_RNG_MAKER(2)
template <typename Op>
void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
auto&& rng = op.cast_final_safe<Op>();
auto dest = outputs[0];
@@ -418,6 +419,18 @@ SmallVector<LogicalTensorDesc> infer_output_attrs(
return {dest};
}
template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto &&dest = infer_output_attrs<Op>(def, inputs_tensors);
SmallVector<MemoryDesc> outputs = {{dest[0].layout, 0, dest[0].comp_node, StorageIdentifier::make(1)}};
return {outputs, {}};
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
@@ -427,10 +440,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
exec<Op>(def, inputs, outputs, {});
return outputs;
}
template <typename Op>
void execute(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
exec<Op>(def, inputs, outputs, {});
}
template<typename Op>
SymbolVar apply_on_var_node(
const OpDef& def,
@@ -492,6 +514,8 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.apply_on_var_node(apply_on_var_node<NAME>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.infer_output_mem_desc(infer_output_mem_desc<NAME>) \
.execute(execute<NAME>) \
.fallback(); \
} \
...
@@ -86,22 +86,22 @@ void apply_on_device_tensornd(
(*outputs)[0] = DeviceTensorND::make_proxy(hv);
}
HostTensorND get_var_shape_host_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<DeviceTensorND> input_tensornds;
input_tensornds.reserve(inputs.size());
for (auto&& inp : inputs) {
input_tensornds.push_back(inp->dev_tensor());
}
SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}};
apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
// restore to input comp_node
return HostTensorND::make_proxy(output_tensornds[0]).proxy_to_comp_node(inputs[0]->comp_node());
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -142,6 +142,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
SmallVector<MemoryDesc> ret;
auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
if (blob) {
ret.push_back({tensor.layout(), 0, inputs[0]->comp_node(),
StorageIdentifier::make(Tensor::make(std::forward<decltype(blob)>(blob), tensor.layout(), tensor))});
} else {
ret.push_back({tensor.layout(), 0, inputs[0]->comp_node(), StorageIdentifier::make(1)});
}
return {ret, {}};
}
void execute(const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
SmallVector<MemoryDesc> ret;
auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
if (!blob || blob->storage() != outputs[0]->blob()->storage()) {
outputs[0]->dev_tensor().copy_from_fixlayout(tensor);
AsyncReleaser::inst()->add(tensor);
}
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::GetVarShape>();
return GetVarShape::make(node->param());
@@ -154,6 +181,8 @@ OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // get_var_shape
@@ -181,6 +210,31 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
return opr;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> param_pack_split_infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& param = def.cast_final_safe<ParamPackSplit>();
mgb_assert(inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size());
auto&& inp = inputs[0];
auto&& shp = inp->layout();
mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
mgb_assert(param.shapes.size() * 2 == param.offsets.size());
SmallVector<MemoryDesc> ret;
auto&& shapes = get_shapes(param.shapes);
size_t dtype_size = inputs[0]->layout().dtype.size();
for (size_t i = 0; i < shapes.size(); ++i) {
// memory forward
ret.push_back({{shapes[i], inputs[0]->dtype()}, param.offsets[i * 2] * dtype_size, inp->comp_node(), StorageIdentifier::make(&inputs_mems[0])});
}
return {ret, {}};
}
void param_pack_split_execute(const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
// do nothing
}
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
@@ -203,6 +257,8 @@ SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
.apply_on_var_node(param_pack_split_apply_on_var_node)
.infer_output_mem_desc(param_pack_split_infer_output_mem_desc)
.execute(param_pack_split_execute)
.apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
.fallback();
@@ -219,6 +275,64 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
return opr;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> param_pack_concat_infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
auto comp_node = inputs.front()->comp_node();
auto dtype = inputs.front()->dtype();
size_t nr_inputs = inputs.size() - 1;
size_t nr_elems = 0;
for (size_t i = 0; i < nr_inputs; ++i) {
auto& input = inputs[i];
mgb_assert(comp_node == input->comp_node(), "inputs for param_pack_concat must in same comp_node");
mgb_assert(dtype == input->dtype(), "inputs for param_pack_concat must have same dtype");
nr_elems += input->layout().total_nr_elems();
}
auto dest_layout = TensorLayout({nr_elems}, dtype);
auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
size_t ws_size;
{
TensorShapeArray src_shapes;
for (size_t i = 0; i < nr_inputs; ++i) {
src_shapes.push_back(inputs[i]->shape());
}
ws_size = caller.op->get_workspace_in_bytes(src_shapes, inputs.back()->shape(), TensorShape{});
}
SmallVector<MemoryDesc> outputs = {{dest_layout, 0, comp_node, StorageIdentifier::make(1)}};
MemoryDesc workspace = {{{ws_size}, dtype::Byte()}, 0, comp_node, StorageIdentifier::make(2)};
return {outputs, {workspace}};
}
void param_pack_concat_execute(const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
auto comp_node = inputs.front()->comp_node();
size_t nr_inputs = inputs.size() - 1;
auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
size_t srcs_size = sizeof(void*)*nr_inputs;
void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
std::shared_ptr<dt_byte> srcs_ptr = {(dt_byte*)srcs_raw_ptr, [comp_node](dt_byte* ptr){
comp_node.free_host(ptr);
}};
TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
for (size_t i = 0; i < nr_inputs; ++i) {
srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr;
}
HostTensorStorage srcs_storage;
srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
megdnn::Workspace dnn_wk(workspace[0]->blob()->storage().get(), workspace[0]->blob()->size());
caller.op->exec({srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(), outputs[0]->dev_tensor().as_megdnn(),
dnn_wk);
AsyncReleaser::inst()->add(HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
}
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
@@ -264,6 +378,8 @@ SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.apply_on_var_node(param_pack_concat_apply_on_var_node)
.infer_output_mem_desc(param_pack_concat_infer_output_mem_desc)
.execute(param_pack_concat_execute)
.apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
.fallback();
} // param_pack
...
@@ -77,149 +77,6 @@ public:
bool CompNodeSyncManager::is_into_atexit = false;
#endif
// Cache for small blobs
// 1. A blob has to be seen twice (within a window) to be eligible for cache
// 2. Cache eviction occurs when cache size reaches a threshold, in least frequently used order
class ConstTensorCache {
public:
struct Entry {
size_t hitcnt = 0;
std::unique_ptr<dt_byte[]> data;
size_t size;
BlobPtr blob;
Entry() = default;
Entry(const dt_byte* ptr, size_t size_, BlobPtr blob_)
: data(new dt_byte[size_]), size(size_), blob(blob_) {
memcpy(data.get(), ptr, size);
}
// does not check input
bool match(const HostTensorND& hv) {
return 0 == memcmp(data.get(), hv.raw_ptr(), hv.layout().span().high_byte);
}
};
using KV = std::pair<uint64_t, Entry>;
bool check(const HostTensorND& hv) {
auto&& layout = hv.layout();
auto&& span = layout.span();
return hv.format().is_default() && !hv.empty() &&
layout.is_contiguous() && span.low_byte == 0 &&
span.high_byte <= max_bytes;
}
// hash storage; does not check input
static uint64_t hash(const HostTensorND& hv) {
auto&& span = hv.layout().span();
return XXHash{}
.update(hv.raw_ptr(), span.high_byte)
.digest();
}
BlobPtr lookup(const HostTensorND& hv) {
if (!check(hv)) {
return {};
}
auto h = hash(hv);
MGB_LOCK_GUARD(mtx);
// lookup in g1
auto it = g1.find(h);
if (it != g1.end()) {
if (!it->second.match(hv)) {
mgb_log_warn("hash collision in const tensor cache");
return {};
}
it->second.hitcnt += 1;
return it->second.blob;
}
// lookup in g0
if (!g0.extract(h) && !g0b.extract(h)) {
maybe_collect_g0();
g0.emplace(h);
return {};
}
// add new entry to g1
maybe_collect_g1();
Entry entry(hv.raw_ptr(), hv.layout().span().high_byte, Tensor(hv).blob());
it = g1.emplace_hint(it, h, std::move(entry));
it->second.hitcnt += 1;
return it->second.blob;
}
void clear() {
MGB_LOCK_GUARD(mtx);
g0.clear();
g0b.clear();
g1.clear();
}
std::mutex mtx;
const size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536;
private:
void maybe_collect_g0() {
if (g0.size() > window) {
std::swap(g0, g0b);
g0.clear();
}
}
void maybe_collect_g1() {
if (g1.size() < hwm) return;
tmp.clear();
for (auto&& kv : g1) {
tmp.emplace_back(kv.first, std::move(kv.second));
}
std::nth_element(tmp.begin(), tmp.begin() + lwm, tmp.end(), [](const KV& lhs, const KV& rhs) {
return lhs.second.hitcnt > rhs.second.hitcnt;
});
tmp.resize(lwm);
g1.clear();
for (auto&& kv : tmp) {
kv.second.hitcnt = 0;
g1.emplace(std::move(kv));
}
}
// g0: records blobs which have been seen at least once (within a window)
// g0b: backup of g0
// g1: records the most frequently used blobs which have been seen at least
// twice. When `g1.size() == hwm`, it will be refreshed and only the top
// `lhw` frequently used blobs will be kept.
std::unordered_set<uint64_t> g0, g0b;
std::unordered_map<uint64_t, Entry> g1;
std::vector<KV> tmp;
public:
ConstTensorCache() {
g0.reserve(window), g0b.reserve(window);
g1.reserve(hwm), tmp.reserve(hwm);
}
};
struct MultiCNConstTensorCache : CompNodeDepedentObject {
std::mutex mtx;
CompNode::UnorderedMap<ConstTensorCache> cn2cache;
std::shared_ptr<void> on_comp_node_finalize() {
MGB_LOCK_GUARD(mtx);
cn2cache.clear();
return {};
}
BlobPtr lookup(const HostTensorND& hv) {
MGB_LOCK_GUARD(mtx);
return cn2cache[hv.comp_node()].lookup(hv);
}
static MultiCNConstTensorCache& inst() {
static MultiCNConstTensorCache sl_inst;
return sl_inst;
}
};
} // namespace
void EventDeleter::operator()(CompNode::Event* event) {
...
@@ -522,9 +522,10 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
void ProxyGraph::invoke_op(const OpDef& opdef,
const SmallVector<Tensor*>& inputs,
const SmallVector<Tensor*>& outputs,
const SmallVector<Tensor*>& workspaces) {
CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
init_output_tensor(outputs, workspaces);
for (auto oup : m_cur_opr->output()) {
m_graph->add_used_comp_node(oup->comp_node());
}
@@ -544,19 +545,30 @@ void ProxyGraph::cleanup() {
m_cur_opr = nullptr;
}
void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs, const SmallVector<Tensor*>& workspaces) {
// get proxy opr
auto proxy = m_cur_opr;
do_shape_infer(true);
size_t j = 0;
size_t k = 0;
for (auto&& var : proxy->output()) {
auto &&chk = var->m_mem_plan.reset_from_owner_var().chunk();
if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
// workspace
if (workspaces.size()) {
mgb_assert(k < workspaces.size());
auto && layout = workspaces[k]->layout();
mgb_assert(var->comp_node() == workspaces[k]->comp_node() &&
var->shape().eq_shape(layout) &&
var->dtype() == layout.dtype);
var->m_dev_tensor = workspaces[k]->dev_tensor();
++ k;
} else {
TensorLayout layout{var->shape(), var->dtype(), var->format()};
var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(var->comp_node(), layout);
}
} else {
mgb_assert(j < outputs.size());
auto &&tensor = outputs[j];
@@ -570,6 +582,7 @@ void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs) {
chk.mem_alloc_status.set_from_owner_var();
}
mgb_assert(j == outputs.size());
mgb_assert(k == workspaces.size());
// Memory forwarding was bypassed in megbrain with graph option
// imerative_proxy_graph on, here we call mem_plan_fwd_in2out_readonly
@@ -623,6 +636,26 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_
return {outputs, validated && !need_check};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::infer_output_mem_desc(
const OpDef& def,
const SmallVector<Tensor*>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto opr = get_proxy_opr(def, inputs_tensors);
CUR_OPR_GUARD(opr);
do_shape_infer(true);
SmallVector<MemoryDesc> outputs;
SmallVector<MemoryDesc> workspaces;
size_t cur_id = 0;
for (auto&& i : opr->output()) {
if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
workspaces.push_back({{i->shape(), i->dtype(), i->format()}, 0, i->comp_node(), StorageIdentifier::make(++ cur_id)});
} else {
outputs.push_back({{i->shape(), i->dtype()}, 0, i->comp_node(), StorageIdentifier::make(++ cur_id)});
}
}
return {outputs, workspaces};
}
struct ProxyGraph::GradGraph {
cg::VarNodeArray inputs;
cg::VarNodeArray outputs;
...
@@ -37,7 +37,8 @@ public:
void invoke_op(
const OpDef& opdef,
const SmallVector<Tensor*>& inputs,
const SmallVector<Tensor*>& outputs,
const SmallVector<Tensor*>& workspace);
BackwardGraphResult make_backward_graph(
const OpDef& opdef,
@@ -45,6 +46,11 @@ public:
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<Tensor*>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);
/********************** Logical Tensor API **********************/
size_t get_opr_output_size(
@@ -74,7 +80,8 @@ private:
void cleanup();
void init_output_tensor(
const SmallVector<Tensor*>& outputs,
const SmallVector<Tensor*>& workspace);
cg::OperatorNodeBase* get_proxy_opr(
const OpDef& opdef,
...
@@ -43,10 +43,12 @@ infer_output_attrs(const OpDef& def,
void exec(const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspaces) {
auto&& graph = ProxyGraph::get_default_graph();
auto raw_inputs = to_raw_ptr_array(inputs),
raw_outputs = to_raw_ptr_array(outputs),
raw_workspaces = to_raw_ptr_array(workspaces);
CompNode::UnorderedSet used_cns;
for (auto&& out: raw_outputs) {
auto cn = out->comp_node();
@@ -59,7 +61,7 @@ void exec(const OpDef& def,
}
}
}
graph->invoke_op(def, raw_inputs, raw_outputs, raw_workspaces);
for (auto&& cn: used_cns) {
for (auto&& in: inputs) {
if (in->comp_node() != cn) {
@@ -77,7 +79,7 @@ apply_on_physical_tensor(const OpDef& def,
for (size_t i = 0; i < outputs.size(); i++) {
outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
exec(def, inputs, outputs, {});
auto async_error = ProxyGraph::get_async_error();
if (async_error) {
throw *async_error;
@@ -85,6 +87,26 @@ apply_on_physical_tensor(const OpDef& def,
return outputs;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& graph = ProxyGraph::get_default_graph();
return graph->infer_output_mem_desc(def, to_raw_ptr_array(inputs_tensors), inputs_mems);
}
void execute(const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
exec(def, inputs, outputs, workspace);
auto async_error = ProxyGraph::get_async_error();
if (async_error) {
throw *async_error;
}
return;
}
// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
// const SmallVector<LogicalTensorDesc>& inputs) {
// auto&& graph = ProxyGraph::get_default_graph();
...
@@ -109,6 +109,13 @@ public:
const OpDef& def,
SmallVector<TensorPtr> inputs);
static void execute(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace);
/*!
* \brief Call the corresponding dnn op to calculate results. Output
* tensors' device memory should be allocated outside.
@@ -126,6 +133,11 @@ public:
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs);
static std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);
static BackwardGraphResult make_backward_graph(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
...
@@ -150,12 +150,192 @@ private:
EventPtr m_value_ready = nullptr;
};
// Cache for small blobs
// 1. A blob has to be seen twice (within a window) to be eligible for cache
// 2. Cache eviction occurs when cache size reaches a threshold, in least frequently used order
class ConstTensorCache {
public:
struct Entry {
size_t hitcnt = 0;
std::unique_ptr<dt_byte[]> data;
size_t size;
BlobPtr blob;
Entry() = default;
Entry(const dt_byte* ptr, size_t size_, BlobPtr blob_)
: data(new dt_byte[size_]), size(size_), blob(blob_) {
memcpy(data.get(), ptr, size);
}
// does not check input
bool match(const HostTensorND& hv) {
return 0 == memcmp(data.get(), hv.raw_ptr(), hv.layout().span().high_byte);
}
};
using KV = std::pair<uint64_t, Entry>;
bool check(const HostTensorND& hv) {
auto&& layout = hv.layout();
auto&& span = layout.span();
return hv.format().is_default() && !hv.empty() &&
layout.is_contiguous() && span.low_byte == 0 &&
span.high_byte <= max_bytes;
}
// hash storage; does not check input
static uint64_t hash(const HostTensorND& hv) {
auto&& span = hv.layout().span();
return XXHash{}
.update(hv.raw_ptr(), span.high_byte)
.digest();
}
BlobPtr lookup(const HostTensorND& hv) {
if (!check(hv)) {
return {};
}
auto h = hash(hv);
MGB_LOCK_GUARD(mtx);
// lookup in g1
auto it = g1.find(h);
if (it != g1.end()) {
if (!it->second.match(hv)) {
mgb_log_warn("hash collision in const tensor cache");
return {};
}
it->second.hitcnt += 1;
return it->second.blob;
}
// lookup in g0
if (!g0.extract(h) && !g0b.extract(h)) {
maybe_collect_g0();
g0.emplace(h);
return {};
}
// add new entry to g1
maybe_collect_g1();
Entry entry(hv.raw_ptr(), hv.layout().span().high_byte, Tensor(hv).blob());
it = g1.emplace_hint(it, h, std::move(entry));
it->second.hitcnt += 1;
return it->second.blob;
}
void clear() {
MGB_LOCK_GUARD(mtx);
g0.clear();
g0b.clear();
g1.clear();
}
std::mutex mtx;
const size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536;
private:
void maybe_collect_g0() {
if (g0.size() > window) {
std::swap(g0, g0b);
g0.clear();
}
}
void maybe_collect_g1() {
if (g1.size() < hwm) return;
tmp.clear();
for (auto&& kv : g1) {
tmp.emplace_back(kv.first, std::move(kv.second));
}
std::nth_element(tmp.begin(), tmp.begin() + lwm, tmp.end(), [](const KV& lhs, const KV& rhs) {
return lhs.second.hitcnt > rhs.second.hitcnt;
});
tmp.resize(lwm);
g1.clear();
for (auto&& kv : tmp) {
kv.second.hitcnt = 0;
g1.emplace(std::move(kv));
}
}
// g0: records blobs which have been seen at least once (within a window)
// g0b: backup of g0
// g1: records the most frequently used blobs which have been seen at least
// twice. When `g1.size() == hwm`, it will be refreshed and only the top
// `lhw` frequently used blobs will be kept.
std::unordered_set<uint64_t> g0, g0b;
std::unordered_map<uint64_t, Entry> g1;
std::vector<KV> tmp;
public:
ConstTensorCache() {
g0.reserve(window), g0b.reserve(window);
g1.reserve(hwm), tmp.reserve(hwm);
}
};
struct MultiCNConstTensorCache : CompNodeDepedentObject {
std::mutex mtx;
CompNode::UnorderedMap<ConstTensorCache> cn2cache;
std::shared_ptr<void> on_comp_node_finalize() {
MGB_LOCK_GUARD(mtx);
cn2cache.clear();
return {};
}
BlobPtr lookup(const HostTensorND& hv) {
MGB_LOCK_GUARD(mtx);
return cn2cache[hv.comp_node()].lookup(hv);
}
static MultiCNConstTensorCache& inst() {
static MultiCNConstTensorCache sl_inst;
return sl_inst;
}
};
struct LogicalTensorDesc {
TensorLayout layout;
CompNode comp_node;
DeviceTensorND value; // cpu:default
};
struct StorageIdentifier;
struct MemoryDesc {
TensorLayout layout;
size_t offset;
CompNode cn;
std::shared_ptr<StorageIdentifier> id;
};
struct StorageIdentifier {
enum { INVALID, SYS_ALLOC, FROM_OTHER, DEVICE_PTR } tag;
union {
size_t id;
MemoryDesc* desc;
};
TensorPtr ptr;
StorageIdentifier() = default;
StorageIdentifier(size_t id): tag(SYS_ALLOC), id(id) {}
StorageIdentifier(const MemoryDesc* desc): tag(FROM_OTHER), desc(desc->id->desc) {}
StorageIdentifier(TensorPtr dev_ptr): tag(DEVICE_PTR), ptr(dev_ptr) {}
template<typename ...Args>
static std::shared_ptr<StorageIdentifier> make(Args&& ...args) {
return std::make_shared<StorageIdentifier>(std::forward<Args>(args)...);
}
bool is_sys_alloc() {
return tag == SYS_ALLOC;
}
bool is_from_other() {
return tag == FROM_OTHER;
}
bool is_device_ptr() {
return tag == DEVICE_PTR;
}
bool is_invalid() {
return tag == INVALID;
}
};
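// How ChannelImpl::init_output_and_workspace() consumes these tags:
//   SYS_ALLOC  -> allocate a fresh Tensor::make(layout, cn); the id is first
//                 refined to a unique ++m_storage_id;
//   FROM_OTHER -> forward memory from the matching input via
//                 inputs[j]->sub(offset, layout);
//   DEVICE_PTR -> reuse the tensor carried in `ptr` directly
//                 (e.g. a blob found in MultiCNConstTensorCache by GetVarShape).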
} // namespace imperative
} // namespace mgb
...
@@ -21,9 +21,19 @@ SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
SmallVector<TensorPtr> inputs);
void execute(const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs);
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);
void exec(const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs);
...