Commit 634de590 — authored by: Megvii Engine Team

feat(mge/imperative): add valid flag of `infer_output_attrs_fallible`

GitOrigin-RevId: b2b32774eeb893503c25d3434fa6f2ba64f1c8c6
Parent: 50c4daac
......@@ -63,15 +63,17 @@ SmallVector<void*> ChannelImpl::apply_op(
input_infos.push_back(info);
input_descs.push_back(info->desc);
}
auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
ApplyOp cmd{std::move(op)};
cmd.inputs = std::move(input_infos);
cmd.outputs.reserve(output_descs.size());
SmallVector<void*> outputs;
bool is_fallible = false;
// FIXME: remove this check when op check is correct
bool validated_bkp = true;
for (auto&& desc : output_descs) {
if (desc.layout.ndim == 0) {
is_fallible = true;
validated_bkp = false;
}
auto info = alloc();
info->desc = desc;
......@@ -80,8 +82,14 @@ SmallVector<void*> ChannelImpl::apply_op(
outputs.push_back(info);
}
m_worker.add_task(std::move(cmd));
if (is_fallible && m_async_level <= 1) {
if (!(validated && validated_bkp) && m_async_level == 1) {
sync();
} else if (m_async_level == 0) {
sync();
// check device error
for (auto&& oup : cmd.outputs) {
oup->ptr->comp_node().sync();
}
}
return outputs;
}
......@@ -194,6 +202,9 @@ ChannelImpl::~ChannelImpl() {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
MGB_LOCK_GUARD(m_mutex);
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
if (m_waitee == dest) {
m_cv.notify_all();
......
......@@ -42,7 +42,7 @@ cg::OperatorNodeBase* OpDef::apply_on_var_node(
return def.trait()->apply_on_var_node(def, inputs);
}
SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible(
std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
return def.trait()->infer_output_attrs_fallible(def, inputs);
......
......@@ -24,12 +24,12 @@ BackwardGraph::InternalGraph::apply(
inputs);
}
SmallVector<LogicalTensorDesc>
BackwardGraph::InternalGraph::infer_attrs(
std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs(
const SmallVector<LogicalTensorDesc>& inputs) const {
using TensorAttr = LogicalTensorDesc;
ThinHashMap<size_t, TensorAttr> node2attr;
auto&& input_nodes = this->inputs;
auto&& output_nodes = this->outputs;
mgb_assert(inputs.size() == input_nodes.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
node2attr[input_nodes[i]] = inputs[i];
......@@ -41,25 +41,29 @@ BackwardGraph::InternalGraph::infer_attrs(
i.second->layout(), i.second->comp_node(),
value->proxy_to_default_cpu()};
}
bool validated = true;
for (size_t i = 0; i < exprs.size(); ++ i) {
auto&& expr = exprs[i];
SmallVector<TensorAttr> inputs;
for (auto &&in : std::get<1>(expr)) {
inputs.push_back(node2attr.at(in));
auto&& [expr_op, expr_inps, expr_oups] = exprs[i];
SmallVector<TensorAttr> expr_input_descs;
for (auto &&inp : expr_inps) {
expr_input_descs.push_back(node2attr.at(inp));
}
auto outputs = OpDef::infer_output_attrs_fallible(
*std::get<0>(expr), inputs);
auto output_nodes = std::get<2>(expr);
mgb_assert(outputs.size() == output_nodes.size());
for (size_t i = 0; i < outputs.size(); ++ i) {
node2attr[output_nodes[i]] = outputs[i];
auto[expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
*expr_op, expr_input_descs);
validated = validated && expr_validated;
mgb_assert(expr_output_descs.size() == expr_oups.size());
for (size_t i = 0; i < expr_output_descs.size(); ++ i) {
node2attr[expr_oups[i]] = expr_output_descs[i];
}
}
SmallVector<TensorAttr> ret;
for (auto &&i : outputs) {
for (auto &&i : output_nodes) {
ret.push_back(node2attr.at(i));
}
return ret;
return {ret, validated};
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph);
......@@ -72,11 +76,11 @@ SmallVector<TensorPtr> backward_impl(
.graph().apply(tensors);
}
// NOTE(review): this span is a rendered diff hunk — the old and new version of
// each changed line appear on adjacent lines with no +/- markers. Do not treat
// it as compilable C++; read line pairs as before/after.
// Old signature: returned only the output descriptors.
SmallVector<LogicalTensorDesc> infer_tensor_attrs(
// New signature: additionally returns a bool "validated" flag, propagated from
// BackwardGraph::InternalGraph::infer_attrs (true when every sub-op's shape
// inference fully succeeded).
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs(
        const OpDef& backward_graph,
        const SmallVector<LogicalTensorDesc> inputs) {
    // Downcast to the concrete BackwardGraph op and delegate to its internal
    // graph's attribute inference.
    return backward_graph.cast_final_safe<BackwardGraph>()
// Old line (pre-change indentation):
    .graph().infer_attrs(inputs);
// New line (whitespace-only re-indent in the diff):
        .graph().infer_attrs(inputs);
}
OP_TRAIT_REG(BackwardGraph, BackwardGraph)
......
......@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node(
}
}
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<BatchNorm>();
......@@ -66,7 +66,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
out_shapes[i] = {i1.layout, i1.comp_node};
}
out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
return out_shapes;
return {out_shapes, true};
}
OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
......
......@@ -47,7 +47,7 @@ bool valid_broadcast(const TensorShape& src_shape,
return true;
}
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
def.cast_final_safe<Broadcast>();
......@@ -59,7 +59,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
TensorLayout out_layout = src.layout;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_layout.ndim = 0;
return {{out_layout, src.comp_node}};
return {{{out_layout, src.comp_node}}, true};
}
mgb_assert(
tshp.layout.ndim == 1,
......@@ -77,7 +77,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
src.layout.TensorShape::to_string().c_str(),
out_layout.TensorShape::to_string().c_str());
return {{out_layout, src.comp_node}};
return {{{out_layout, src.comp_node}}, true};
}
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
......
......@@ -25,7 +25,7 @@ namespace {
class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy {
using Output = std::array<TensorPtr, 2>;
CompNode m_cn;
Output m_out;
......@@ -110,14 +110,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return out;
}
// NOTE(review): diff hunk — adjacent lines are old/new pairs of the same
// statement; not compilable as-is.
// Old signature: returned only the descriptors.
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
// New signature: also returns the "validated" flag.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    // CondTake outputs live on the same comp node as its first input.
    auto cn = inputs[0].comp_node;
    // Old return: a bare SmallVector of two descriptors.
    return {
    // New return: the same two descriptors wrapped in a tuple with
    // validated == true (CondTake's output dtypes are always statically known;
    // the layouts are left empty — ndim 0 — since the result size is dynamic).
    return {{
        {TensorLayout(inputs[0].layout.dtype), cn},
        {TensorLayout(dtype::Int32()), cn}
    // Old closing of the return expression:
    };
    // New closing, supplying the validated flag:
    }, true};
}
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
......@@ -128,4 +128,4 @@ OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
} // namespace
} // namespace mgb::imperative
\ No newline at end of file
} // namespace mgb::imperative
......@@ -29,7 +29,7 @@ cg::OperatorNodeBase* apply_on_var_node(
return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
}
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
......@@ -40,7 +40,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
TensorShapeArray inp_shapes;
DType out_dt;
CompNode out_cn;
for (size_t i = 0; i < inputs.size(); ++ i) {
for (size_t i = 0; i < inputs.size(); ++ i) {
auto &&t = inputs[i];
if (!i) {
out_cn = t.comp_node;
......@@ -55,12 +55,12 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
TensorLayout out_layout;
out_layout.ndim = 0;
out_layout.dtype = out_dt;
return {{out_layout, out_cn}};
return {{{out_layout, out_cn}}, true};
}
}
auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
return {{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}};
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
......
......@@ -40,21 +40,21 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return {Tensor::make(std::move(hv))};
}
// NOTE(review): diff hunk — adjacent lines are old/new pairs; not compilable
// as-is.
// Old signature: descriptors only.
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
// New signature: descriptors plus the "validated" flag.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    // Type-check the op (result discarded; cast_final_safe asserts on mismatch).
    def.cast_final_safe<GetVarShape>();
    mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
    auto&& desc = inputs[0];
    // Input shape unknown yet: return an Int32 descriptor with empty layout.
    if (!desc.layout.ndim) {
        // Old return:
        return {{TensorLayout(dtype::Int32()), desc.comp_node}};
        // New return — note validated is reported true even on this early-exit
        // path; presumably deliberate, since GetVarShape's output is always a
        // 1-d Int32 vector — TODO confirm against the scheduler's use of the flag.
        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
    }
    // Shape is known: materialize it as a host-side Int32 tensor so the value
    // itself can participate in static inference downstream.
    DeviceTensorND value(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
    auto* ptr = value.ptr<dt_int32>();
    for (size_t i = 0; i < desc.layout.ndim; ++i) {
        ptr[i] = desc.layout[i];
    }
    // Old return:
    return {{value.layout(), desc.comp_node, std::move(value)}};
    // New return with validated == true:
    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
......
......@@ -28,12 +28,13 @@ namespace {
CompNode::UnorderedSet collect_comp_nodes(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
CompNode::UnorderedSet comp_nodes;
SmallVector<LogicalTensorDesc> descs;
SmallVector<LogicalTensorDesc> inp_descs;
for (auto&& i : inputs) {
comp_nodes.insert(i->comp_node());
descs.push_back({i->layout(), i->comp_node(), {}});
inp_descs.push_back({i->layout(), i->comp_node(), {}});
}
for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) {
SmallVector<LogicalTensorDesc> oup_descs = std::get<0>(def.infer_output_attrs_fallible(def, inp_descs));
for (auto&& output_attr : oup_descs) {
comp_nodes.insert(output_attr.comp_node);
}
return comp_nodes;
......
......@@ -14,6 +14,7 @@
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/backward_graph.h"
......@@ -590,10 +591,9 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
}
auto opr = OpDef::apply_on_var_node(opdef, vinputs);
mgb_assert(opr->dyn_typeinfo() != InputPlaceholder::typeinfo());
mgb_assert(!opr->same_type<InputPlaceholder>());
for (auto &&i : opr->input()) {
mgb_assert(i->owner_opr()->dyn_typeinfo() ==
InputPlaceholder::typeinfo());
mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
}
return opr;
}
......@@ -605,17 +605,18 @@ size_t ProxyGraph::get_opr_output_size(const OpDef& opdef,
return get_proxy_opr(opdef, inputs)->usable_output().size();
}
// NOTE(review): diff hunk — adjacent lines are old/new pairs; not compilable
// as-is.
// Old signature: descriptors only.
SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs_fallible(
// New signature: descriptors plus the "validated" flag from shape inference.
std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible(
        const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& inputs) {
    // Build (or fetch) a proxy operator node for this OpDef over placeholder
    // inputs, and scope it as the current operator.
    auto opr = get_proxy_opr(opdef, inputs);
    CUR_OPR_GUARD(opr);
    // Old code: result vector named `ret`; do_shape_infer returned void.
    do_shape_infer(false);
    SmallVector<LogicalTensorDesc> ret;
    // New code: do_shape_infer now reports whether every output shape was
    // statically inferable.
    SmallVector<LogicalTensorDesc> outputs;
    bool validated = do_shape_infer(false);
    for (auto&& i : opr->usable_output()) {
        // Old line:
        ret.push_back({{i->shape(), i->dtype()}, i->comp_node()});
        // New line (rename only):
        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
    }
    // Old return:
    return ret;
    // New: Reshape's inferred shape is force-rechecked downstream (per the
    // commit's "FIXME: remove this check when op check is correct"), so its
    // result is reported as not validated regardless of inference success.
    bool need_check = opr->same_type<opr::Reshape>();
    return {outputs, validated && !need_check};
}
struct ProxyGraph::GradGraph {
......@@ -811,16 +812,20 @@ VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTenso
/*********************** Common Impl ***********************/
// NOTE(review): diff hunk — adjacent lines are old/new pairs; not compilable
// as-is.
// Old signature: returned void.
void ProxyGraph::do_shape_infer(bool sync_value) {
// New signature: returns true iff every output var's shape was inferable.
bool ProxyGraph::do_shape_infer(bool sync_value) {
    // Refresh the static-infer manager before querying shapes.
    m_static_infer_manager->update();
    bool validated = true;
    for (auto* var : m_cur_opr->output()) {
        if (sync_value) {
            // Blocking inference: infer_shape must succeed (may sync values).
            var->shape(m_static_infer_manager->infer_shape(var));
        } else if (auto* shape = m_static_infer_manager->infer_shape_fallible(var)) {
            // Old line:
            var->shape(*shape);
            // New line (whitespace-only change in the diff):
            var->shape(*shape);
        } else {
            // Fallible inference failed for this var: leave its shape as-is
            // and report the overall result as not validated.
            validated = false;
        }
    }
    return validated;
}
TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
......
......@@ -48,7 +48,7 @@ public:
const OpDef& opdef,
const SmallVector<LogicalTensorDesc>& inputs);
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& opdef,
const SmallVector<LogicalTensorDesc>& inputs);
......@@ -88,7 +88,7 @@ private:
/********************** Common Helper **********************/
void do_shape_infer(bool sync_value);
bool do_shape_infer(bool sync_value);
TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share=true);
......
......@@ -80,8 +80,7 @@ apply_on_physical_tensor(const OpDef& def,
return outputs;
}
SmallVector<LogicalTensorDesc>
infer_output_attrs_fallible(const OpDef& def,
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& graph = ProxyGraph::get_default_graph();
return graph->infer_output_attrs_fallible(def, inputs);
......@@ -136,4 +135,4 @@ make_backward_graph(const OpDef& def,
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -21,8 +21,7 @@ SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
const SmallVector<TensorPtr>& inputs);
SmallVector<LogicalTensorDesc>
infer_output_attrs_fallible(const OpDef& def,
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs);
BackwardGraphResult
......@@ -35,4 +34,4 @@ make_backward_graph(const OpDef& def,
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -44,7 +44,7 @@ public:
const OpDef& def,
const VarNodeArray& inputs);
static SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
static std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs);
......
......@@ -38,8 +38,8 @@ public:
SmallVector<TensorPtr>
apply(const SmallVector<TensorPtr>& inputs) const;
SmallVector<LogicalTensorDesc>
infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const;
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs(
const SmallVector<LogicalTensorDesc>& inputs) const;
template <typename T, typename F, typename C>
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.