提交 aa587446 编写于 作者: M Megvii Engine Team

feat(subgraph): support shape inference for CompiledOp

GitOrigin-RevId: a96b8f344673eef63126bfc6122e55a5dd2eac58
上级 1c1e9b00
......@@ -18,6 +18,7 @@
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
......@@ -289,24 +290,232 @@ struct DeviceMemoryAllocatorImpl : cg::DeviceMemoryAllocator {
}
};
enum class HolderKind {
ShapeInfer,
Execute,
};
template <HolderKind Kind>
struct ComputingGraphHolder {
struct Input {
std::shared_ptr<DeviceTensorND> device_value;
std::shared_ptr<HostTensorND> host_value;
std::shared_ptr<HostTensorND> host_shape;
};
std::shared_ptr<ComputingGraph> graph;
std::unique_ptr<cg::AsyncExecutable> executable;
SmallVector<std::shared_ptr<DeviceTensorND>> inputs;
SmallVector<std::shared_ptr<DeviceTensorND>> outputs;
SmallVector<Input> inputs;
SmallVector<std::shared_ptr<DeviceTensorND>> device_outputs;
SmallVector<VarNode*> input_vars;
SmallVector<VarNode*> output_vars;
std::shared_ptr<DeviceMemoryAllocatorImpl> allocator;
SmallVector<std::unique_ptr<CompNode::Event>> events;
std::unique_ptr<cg::static_infer::StaticInferUpdater> updater;
void initialize(
const CompiledOp& op, const SmallVector<LogicalTensorDesc>& input_descs) {
allocator = std::make_shared<DeviceMemoryAllocatorImpl>();
graph = ComputingGraph::make();
graph->options().force_dynamic_alloc = true;
graph->options().async_exec_level = 0;
graph->options().graph_opt_level = op.gopt_level;
graph->options().enable_var_mem_defragment = false;
graph->options().comp_seq_sync_device = false;
// set allocator for DTR support
graph->set_device_memory_allocator(allocator);
if constexpr (Kind == HolderKind::ShapeInfer) {
updater = cg::static_infer::StaticInferUpdater::make();
}
for (auto&& desc : input_descs) {
Input input;
VarNode* input_var = nullptr;
if constexpr (Kind == HolderKind::Execute) {
input.device_value = std::make_shared<DeviceTensorND>();
input.device_value->dtype(desc.layout.dtype);
input.device_value->comp_node(desc.comp_node);
input.device_value->resize(desc.layout);
auto callback = [value = input.device_value] { return *value; };
if (!desc.value.empty()) {
input.host_value = std::make_shared<HostTensorND>();
input.host_value->dtype(desc.layout.dtype);
input.host_value->comp_node(desc.comp_node);
}
input_var = opr::MutableTensor::make(
*graph, input.device_value, input.host_value, {})
.node();
// input_var = opr::VolatileSharedDeviceTensor::make(*graph,
// input.device_value).node();
} else if constexpr (Kind == HolderKind::ShapeInfer) {
if (desc.value.empty()) {
input.host_shape = std::make_shared<HostTensorND>();
input.host_shape->dtype(dtype::Int32());
input.host_shape->comp_node(desc.comp_node);
auto input_shape_var =
opr::Host2DeviceCopy::make(*graph, input.host_shape);
input_var =
opr::Alloc::make(input_shape_var, desc.layout.dtype).node();
} else {
input.host_value = std::make_shared<HostTensorND>();
input.host_value->dtype(desc.layout.dtype);
input.host_value->comp_node(desc.comp_node);
input_var =
opr::Host2DeviceCopy::make(*graph, input.host_value).node();
}
} else {
static_assert((Kind != Kind), "unknown holder kind");
}
input_vars.push_back(input_var);
inputs.push_back(input);
}
// forward to inner op
output_vars = OpDef::apply_on_var_node(*op.op, input_vars);
ComputingGraph::OutputSpec output_spec;
CompNode::UnorderedSet comp_nodes;
for (auto&& output_var : output_vars) {
using namespace cg::static_infer;
auto output_ptr = std::make_shared<DeviceTensorND>();
auto callback = [output_ptr](DeviceTensorND output) {
output_ptr->reset(output.storage(), output.layout());
output = {};
};
if constexpr (Kind == HolderKind::ShapeInfer) {
output_spec.push_back({output_var, callback});
auto it = graph->static_infer_manager().get_infer_type(output_var);
if (it.shape == InferType::RT_STATIC) {
updater->add_dest({output_var, DepType::SHAPE});
}
if (it.value == InferType::RT_STATIC) {
updater->add_dest({output_var, DepType::VALUE});
}
} else {
auto output_callback_var =
opr::OutputCallback::make({callback}, output_var);
output_spec.push_back({output_callback_var, {}});
}
device_outputs.push_back(output_ptr);
}
executable = graph->compile(output_spec);
executable->iter_opr_seq([&](cg::OperatorNodeBase* opr) -> bool {
for (auto&& output : opr->output()) {
comp_nodes.insert(output->comp_node());
}
return true;
});
for (auto&& comp_node : comp_nodes) {
events.push_back(comp_node.create_event());
events.back()->record();
}
}
template <
HolderKind ThisKind = Kind,
typename = std::enable_if_t<ThisKind == HolderKind::Execute>>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<LogicalTensorDesc> input_descs,
const SmallVector<TensorPtr>& input_tensors) {
// wait for last execution
executable->wait();
size_t nr_inputs = inputs.size();
for (size_t i = 0; i < nr_inputs; ++i) {
auto input_dev_tensor = input_tensors[i]->dev_tensor();
inputs[i].device_value->reset(
input_dev_tensor.storage(), input_dev_tensor.layout());
if (inputs[i].host_value) {
inputs[i].host_value->copy_from(input_descs[i].value);
}
}
allocator->current_op = const_cast<OpDef&>(def).shared_from_this();
executable->execute();
for (auto&& event : events) {
event->record();
}
SmallVector<TensorPtr> outputs_tensors;
for (auto input : inputs) {
*input.device_value = {};
if (input.host_value) {
*input.host_value = {};
}
}
for (auto output_nd : device_outputs) {
outputs_tensors.push_back(Tensor::make(*output_nd));
*output_nd = {};
}
executable->clear_device_memory();
allocator->current_op = nullptr;
return outputs_tensors;
}
template <
HolderKind ThisKind = Kind,
typename = std::enable_if_t<ThisKind == HolderKind::ShapeInfer>>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
executable->wait();
size_t nr_inputs = input_vars.size(), nr_outputs = output_vars.size();
SmallVector<LogicalTensorDesc> output_descs(nr_outputs);
for (size_t i = 0; i < nr_inputs; ++i) {
if (inputs[i].host_shape) {
DeviceTensorND input_shape_device_nd;
cg::copy_shape_to_tensor_value(
input_shape_device_nd, input_descs[i].layout);
inputs[i].host_shape->copy_from(input_shape_device_nd);
mgb_assert(input_descs[i].layout.ndim, "ndim == 0");
} else if (inputs[i].host_value) {
inputs[i].host_value->copy_from(input_descs[i].value);
}
}
updater->update();
bool validated = true;
for (size_t i = 0; i < nr_outputs; ++i) {
auto infer_type =
graph->static_infer_manager().get_infer_type(output_vars[i]);
const TensorShape* output_shape = nullptr;
const DeviceTensorND* output_value = nullptr;
auto& desc = output_descs[i];
if (infer_type.shape != cg::static_infer::InferType::NO_DESC) {
output_shape = graph->static_infer_manager().infer_shape_fallible(
output_vars[i]);
}
if (infer_type.value != cg::static_infer::InferType::NO_DESC) {
output_value = graph->static_infer_manager().infer_value_fallible(
output_vars[i]);
}
if (output_shape && output_value) {
mgb_assert(
output_shape->eq_shape(output_value->shape()),
"shape infer result mismatch, %s vs %s",
output_shape->to_string().c_str(),
output_value->shape().to_string().c_str());
}
if (output_shape) {
((TensorShape&)desc.layout) = *output_shape;
}
if (output_value) {
((TensorShape&)desc.layout) = output_value->shape();
desc.value = *output_value;
}
desc.layout.dtype = output_vars[i]->dtype();
desc.comp_node = output_vars[i]->comp_node();
if (!desc.layout.ndim) {
validated = false;
}
desc.layout.init_contiguous_stride();
}
return {output_descs, validated};
}
};
ComputingGraphHolder& get_computing_graph(
std::shared_ptr<OpDef> compiled_op, SmallVector<LogicalTensorDesc> descs) {
template <HolderKind Kind>
ComputingGraphHolder<Kind>& get_computing_graph(
std::shared_ptr<OpDef> compiled_op,
const SmallVector<LogicalTensorDesc>& descs) {
using ComputingGraphHolderCache =
OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder>>>;
OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder<Kind>>>>;
thread_local auto cache = std::make_unique<ComputingGraphHolderCache>();
thread_local size_t nr_cg_holders = 0;
ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
auto& cg_holder_queue = (*cache)[cache_key];
std::unique_ptr<ComputingGraphHolder> holder;
std::unique_ptr<ComputingGraphHolder<Kind>> holder;
if (!cg_holder_queue.empty()) {
// pick one
std::swap(cg_holder_queue.front(), holder);
......@@ -326,6 +535,8 @@ ComputingGraphHolder& get_computing_graph(
"for prev graph");
event->host_wait();
}
} else {
event->host_wait();
}
}
if (holder) {
......@@ -334,54 +545,9 @@ ComputingGraphHolder& get_computing_graph(
}
if (!holder) {
// create new computing graph
holder = std::make_unique<ComputingGraphHolder>();
holder = std::make_unique<ComputingGraphHolder<Kind>>();
auto& cg_holder = *holder;
cg_holder.allocator = std::make_shared<DeviceMemoryAllocatorImpl>();
cg_holder.graph = ComputingGraph::make();
cg_holder.graph->options().force_dynamic_alloc = true;
cg_holder.graph->options().async_exec_level = 0;
cg_holder.graph->options().graph_opt_level =
compiled_op->cast_final_safe<CompiledOp>().gopt_level;
cg_holder.graph->options().enable_var_mem_defragment = false;
cg_holder.graph->options().comp_seq_sync_device = false;
// set allocator for DTR support
cg_holder.graph->set_device_memory_allocator(cg_holder.allocator);
VarNodeArray input_vars;
for (auto&& desc : descs) {
auto input_device_nd = std::make_shared<DeviceTensorND>();
input_device_nd->dtype(desc.layout.dtype);
input_device_nd->comp_node(desc.comp_node);
input_device_nd->resize(desc.layout);
cg_holder.inputs.push_back(input_device_nd);
auto callback = [input_device_nd] { return *input_device_nd; };
auto* input_var = opr::InputCallback::make(
*cg_holder.graph, callback, desc.comp_node,
desc.layout.dtype, TensorShape())[0]
.node();
input_vars.push_back(input_var);
}
// forward to inner op
auto output_vars = OpDef::apply_on_var_node(*compiled_op, input_vars);
ComputingGraph::OutputSpec output_spec;
size_t nr_outputs = output_vars.size();
for (size_t i = 0; i < nr_outputs; ++i) {
auto* output_var = output_vars[i];
auto output_ptr = std::make_shared<DeviceTensorND>();
auto callback = [output_ptr](DeviceTensorND output) {
output_ptr->reset(output.storage(), output.layout());
};
output_spec.push_back({output_var, callback});
cg_holder.outputs.push_back(output_ptr);
}
cg_holder.executable = cg_holder.graph->compile(output_spec);
CompNode::UnorderedSet comp_nodes;
for (auto&& output_var : output_vars) {
comp_nodes.insert(output_var->comp_node());
}
for (auto&& comp_node : comp_nodes) {
cg_holder.events.push_back(comp_node.create_event());
cg_holder.events.back()->record();
}
cg_holder.initialize(compiled_op->cast_final_safe<CompiledOp>(), descs);
nr_cg_holders++;
mgb_log_debug(
"add new computing graph for compiled op, now %zu graphs",
......@@ -395,34 +561,18 @@ auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& in
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& input : inputs) {
input_descs.push_back({input->layout(), input->comp_node()});
if (auto* host_value = input->try_get_value()) {
if (host_value->layout().total_nr_elems() <=
MEGDNN_MAX_NDIM) { // infer small tensor
input_descs.back().value = host_value->proxy_to_default_cpu();
}
}
}
size_t nr_inputs = inputs.size();
auto shared_def = const_cast<OpDef&>(def).shared_from_this();
auto& cg_holder = get_computing_graph(shared_def, input_descs);
// wait for last execution
cg_holder.executable->wait();
for (size_t i = 0; i < nr_inputs; ++i) {
auto input_dev_tensor = inputs[i]->dev_tensor();
cg_holder.inputs[i]->reset(
input_dev_tensor.storage(), input_dev_tensor.layout());
}
cg_holder.allocator->current_op = shared_def;
cg_holder.executable->execute();
for (auto&& event : cg_holder.events) {
event->record();
}
SmallVector<TensorPtr> outputs;
for (auto input_nd : cg_holder.inputs) {
*input_nd = {};
}
for (auto output_nd : cg_holder.outputs) {
outputs.push_back(Tensor::make(*output_nd));
*output_nd = {};
}
cg_holder.executable->clear_device_memory();
cg_holder.allocator->current_op = nullptr;
return outputs;
auto& cg_holder = get_computing_graph<HolderKind::Execute>(shared_def, input_descs);
return cg_holder.apply_on_physical_tensor(def, input_descs, inputs);
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto& op = def.cast_final_safe<CompiledOp>();
op.op->set_scope(op.scope());
......@@ -430,9 +580,28 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
auto infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
return OpDef::infer_output_attrs_fallible(
*def.cast_final_safe<CompiledOp>().op, input_descs);
const OpDef& def, SmallVector<LogicalTensorDesc> input_descs) {
bool shape_all_valid = true;
for (auto&& input_desc : input_descs) {
if (!input_desc.layout.ndim) {
shape_all_valid = false;
break;
}
}
if (!shape_all_valid) {
return OpDef::infer_output_attrs_fallible(
*def.cast_final_safe<CompiledOp>().op, input_descs);
}
auto shared_def = const_cast<OpDef&>(def).shared_from_this();
for (auto& input_desc : input_descs) {
if (input_desc.layout.total_nr_elems() >
MEGDNN_MAX_NDIM) { // skip large tensor
input_desc.value = {};
}
}
auto& cg_holder =
get_computing_graph<HolderKind::ShapeInfer>(shared_def, input_descs);
return cg_holder.infer_output_attrs_fallible(def, input_descs);
}
auto props(const OpDef& def) {
......@@ -453,19 +622,8 @@ EncodedSubgraph make_backward_graph(
auto backward_graph = OpDef::make_backward_graph(
*op.op, inputs, input_requires_grad, output_has_grad);
auto name = def.trait()->make_name(def);
auto key = std::make_shared<BackwardOpKey>();
key->op = op.op;
key->inputs = inputs;
key->extras = {input_requires_grad, output_has_grad};
SmallVector<bool> grad_outputs_has_grad(backward_graph.graph.outputs.size(), true);
std::shared_ptr<OpDef> bgraph_op;
if (backward_graph.graph.is_single()) {
bgraph_op = backward_graph.graph.as_single();
} else {
bgraph_op = SubgraphOp::make(
name + "Grad", std::make_shared<Subgraph>(backward_graph.graph),
grad_outputs_has_grad, key);
}
std::shared_ptr<OpDef> bgraph_op =
SubgraphOp::wrap(name + "Grad", backward_graph.graph);
auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
auto encoded_graph = EncodedSubgraph::make_single(
compiled_op, backward_graph.input_mask, backward_graph.output_mask);
......
......@@ -76,6 +76,13 @@ struct SubgraphOp final : OpDefImplBase<SubgraphOp> {
this->output_grad_mask.resize(graph->outputs.size(), true);
}
}
static std::shared_ptr<OpDef> wrap(std::string name, Subgraph graph) {
if (graph.is_single()) {
return graph.as_single();
} else {
return SubgraphOp::make(name, std::make_shared<Subgraph>(graph));
}
}
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
......
......@@ -259,7 +259,7 @@ public:
*/
class StaticInferUpdater : public NonCopyableObj {
public:
static std::unique_ptr<StaticInferUpdater> make();
MGE_WIN_DECLSPEC_FUC static std::unique_ptr<StaticInferUpdater> make();
virtual ~StaticInferUpdater() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册