Commit 634de590 authored by Megvii Engine Team

feat(mge/imperative): add valid flag of `infer_output_attrs_fallible`

GitOrigin-RevId: b2b32774eeb893503c25d3434fa6f2ba64f1c8c6
Parent: 50c4daac
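In short, this commit changes `OpDef::infer_output_attrs_fallible` to return a `std::tuple<SmallVector<LogicalTensorDesc>, bool>` instead of a bare `SmallVector<LogicalTensorDesc>`; the added flag reports whether every output descriptor was fully inferred. A minimal sketch of the new calling convention, with illustrative local names (`op` and `input_descs` are assumed to be in scope, as in `ChannelImpl::apply_op` below):

```cpp
// Unpack the descriptors together with the new validity flag via
// structured bindings (C++17).
auto [output_descs, validated] =
        OpDef::infer_output_attrs_fallible(*op, input_descs);
if (!validated) {
    // At least one output layout could not be statically inferred;
    // callers may fall back to a synchronous path.
}
```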
@@ -63,15 +63,17 @@ SmallVector<void*> ChannelImpl::apply_op(
         input_infos.push_back(info);
         input_descs.push_back(info->desc);
     }
-    auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
+    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     ApplyOp cmd{std::move(op)};
     cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
     SmallVector<void*> outputs;
-    bool is_fallible = false;
+    // FIXME: remove this check when op check is correct
+    bool validated_bkp = true;
     for (auto&& desc : output_descs) {
         if (desc.layout.ndim == 0) {
-            is_fallible = true;
+            validated_bkp = false;
         }
         auto info = alloc();
         info->desc = desc;
@@ -80,8 +82,14 @@ SmallVector<void*> ChannelImpl::apply_op(
         outputs.push_back(info);
     }
     m_worker.add_task(std::move(cmd));
-    if (is_fallible && m_async_level <= 1) {
+    if (!(validated && validated_bkp) && m_async_level == 1) {
+        sync();
+    } else if (m_async_level == 0) {
         sync();
+        // check device error
+        for (auto&& oup : cmd.outputs) {
+            oup->ptr->comp_node().sync();
+        }
     }
     return outputs;
 }
@@ -194,6 +202,9 @@ ChannelImpl::~ChannelImpl() {
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
     MGB_LOCK_GUARD(m_mutex);
     dest->value_fetched = ptr->value_fetched();
+    // update tensor desc for static infer
+    dest->desc.layout = ptr->layout();
+    dest->desc.comp_node = ptr->comp_node();
     dest->ptr = std::move(ptr);
     if (m_waitee == dest) {
         m_cv.notify_all();
......
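Taken together, the interpreter hunks above make the fallback policy explicit. A consolidated view of the resulting dispatch logic, reassembled from the `+` lines above with the surrounding method elided:

```cpp
// After the change: wait for the worker when inference was not fully
// validated and the channel runs at async level 1; at level 0, also
// sync every output's comp node so device errors surface eagerly.
if (!(validated && validated_bkp) && m_async_level == 1) {
    sync();
} else if (m_async_level == 0) {
    sync();
    // check device error
    for (auto&& oup : cmd.outputs) {
        oup->ptr->comp_node().sync();
    }
}
```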
@@ -42,7 +42,7 @@ cg::OperatorNodeBase* OpDef::apply_on_var_node(
     return def.trait()->apply_on_var_node(def, inputs);
 }
 
-SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
     const OpDef& def,
     const SmallVector<LogicalTensorDesc>& inputs) {
     return def.trait()->infer_output_attrs_fallible(def, inputs);
......
@@ -24,12 +24,12 @@ BackwardGraph::InternalGraph::apply(
             inputs);
 }
 
-SmallVector<LogicalTensorDesc>
-BackwardGraph::InternalGraph::infer_attrs(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs(
         const SmallVector<LogicalTensorDesc>& inputs) const {
     using TensorAttr = LogicalTensorDesc;
     ThinHashMap<size_t, TensorAttr> node2attr;
     auto&& input_nodes = this->inputs;
+    auto&& output_nodes = this->outputs;
     mgb_assert(inputs.size() == input_nodes.size());
     for (size_t i = 0; i < inputs.size(); ++ i) {
         node2attr[input_nodes[i]] = inputs[i];
@@ -41,25 +41,29 @@ BackwardGraph::InternalGraph::infer_attrs(
                 i.second->layout(), i.second->comp_node(),
                 value->proxy_to_default_cpu()};
     }
+    bool validated = true;
     for (size_t i = 0; i < exprs.size(); ++ i) {
-        auto&& expr = exprs[i];
-        SmallVector<TensorAttr> inputs;
-        for (auto &&in : std::get<1>(expr)) {
-            inputs.push_back(node2attr.at(in));
+        auto&& [expr_op, expr_inps, expr_oups] = exprs[i];
+        SmallVector<TensorAttr> expr_input_descs;
+        for (auto &&inp : expr_inps) {
+            expr_input_descs.push_back(node2attr.at(inp));
         }
-        auto outputs = OpDef::infer_output_attrs_fallible(
-            *std::get<0>(expr), inputs);
-        auto output_nodes = std::get<2>(expr);
-        mgb_assert(outputs.size() == output_nodes.size());
-        for (size_t i = 0; i < outputs.size(); ++ i) {
-            node2attr[output_nodes[i]] = outputs[i];
+
+        auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
+            *expr_op, expr_input_descs);
+        validated = validated && expr_validated;
+
+        mgb_assert(expr_output_descs.size() == expr_oups.size());
+        for (size_t i = 0; i < expr_output_descs.size(); ++ i) {
+            node2attr[expr_oups[i]] = expr_output_descs[i];
         }
     }
     SmallVector<TensorAttr> ret;
-    for (auto &&i : outputs) {
+    for (auto &&i : output_nodes) {
         ret.push_back(node2attr.at(i));
     }
-    return ret;
+    return {ret, validated};
 }
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph);
@@ -72,11 +76,11 @@ SmallVector<TensorPtr> backward_impl(
             .graph().apply(tensors);
 }
 
-SmallVector<LogicalTensorDesc> infer_tensor_attrs(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs(
         const OpDef& backward_graph,
         const SmallVector<LogicalTensorDesc> inputs) {
     return backward_graph.cast_final_safe<BackwardGraph>()
            .graph().infer_attrs(inputs);
 }
 
 OP_TRAIT_REG(BackwardGraph, BackwardGraph)
......
@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node(
     }
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     const OpDef& def,
     const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& op_def = def.cast_final_safe<BatchNorm>();
@@ -66,7 +66,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
     out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
-    return out_shapes;
+    return {out_shapes, true};
 }
 
 OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
......
@@ -47,7 +47,7 @@ bool valid_broadcast(const TensorShape& src_shape,
     return true;
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     const OpDef& def,
     const SmallVector<LogicalTensorDesc>& inputs) {
     def.cast_final_safe<Broadcast>();
@@ -59,7 +59,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
     TensorLayout out_layout = src.layout;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
         out_layout.ndim = 0;
-        return {{out_layout, src.comp_node}};
+        return {{{out_layout, src.comp_node}}, true};
     }
     mgb_assert(
         tshp.layout.ndim == 1,
@@ -77,7 +77,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
         src.layout.TensorShape::to_string().c_str(),
         out_layout.TensorShape::to_string().c_str());
-    return {{out_layout, src.comp_node}};
+    return {{{out_layout, src.comp_node}}, true};
 }
 
 OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
......
@@ -25,7 +25,7 @@ namespace {
 class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy {
     using Output = std::array<TensorPtr, 2>;
 
     CompNode m_cn;
     Output m_out;
@@ -110,14 +110,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return out;
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     const OpDef& def,
     const SmallVector<LogicalTensorDesc>& inputs) {
     auto cn = inputs[0].comp_node;
-    return {
+    return {{
         {TensorLayout(inputs[0].layout.dtype), cn},
         {TensorLayout(dtype::Int32()), cn}
-    };
+    }, true};
 }
 
 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
@@ -128,4 +128,4 @@ OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
 } // namespace
 } // namespace mgb::imperative
\ No newline at end of file
@@ -29,7 +29,7 @@ cg::OperatorNodeBase* apply_on_var_node(
     return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     const OpDef& def,
     const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& op_def = def.cast_final_safe<Elemwise>();
@@ -40,7 +40,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
     TensorShapeArray inp_shapes;
     DType out_dt;
     CompNode out_cn;
     for (size_t i = 0; i < inputs.size(); ++ i) {
         auto &&t = inputs[i];
         if (!i) {
             out_cn = t.comp_node;
@@ -55,12 +55,12 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
             TensorLayout out_layout;
             out_layout.ndim = 0;
             out_layout.dtype = out_dt;
-            return {{out_layout, out_cn}};
+            return {{{out_layout, out_cn}}, true};
         }
     }
     auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
-    return {{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}};
+    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
 }
 
 OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
......
@@ -40,21 +40,21 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return {Tensor::make(std::move(hv))};
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
    const OpDef& def,
    const SmallVector<LogicalTensorDesc>& inputs) {
     def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
     if (!desc.layout.ndim) {
-        return {{TensorLayout(dtype::Int32()), desc.comp_node}};
+        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
     }
     DeviceTensorND value(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
     auto* ptr = value.ptr<dt_int32>();
     for (size_t i = 0; i < desc.layout.ndim; ++i) {
         ptr[i] = desc.layout[i];
     }
-    return {{value.layout(), desc.comp_node, std::move(value)}};
+    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
 }
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
......
@@ -28,12 +28,13 @@ namespace {
 CompNode::UnorderedSet collect_comp_nodes(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     CompNode::UnorderedSet comp_nodes;
-    SmallVector<LogicalTensorDesc> descs;
+    SmallVector<LogicalTensorDesc> inp_descs;
     for (auto&& i : inputs) {
         comp_nodes.insert(i->comp_node());
-        descs.push_back({i->layout(), i->comp_node(), {}});
+        inp_descs.push_back({i->layout(), i->comp_node(), {}});
     }
-    for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) {
+    SmallVector<LogicalTensorDesc> oup_descs = std::get<0>(def.infer_output_attrs_fallible(def, inp_descs));
+    for (auto&& output_attr : oup_descs) {
         comp_nodes.insert(output_attr.comp_node);
     }
     return comp_nodes;
......
@@ -14,6 +14,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/graph/operator_node.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/utility.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/ops/backward_graph.h"
@@ -590,10 +591,9 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
         vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
     }
     auto opr = OpDef::apply_on_var_node(opdef, vinputs);
-    mgb_assert(opr->dyn_typeinfo() != InputPlaceholder::typeinfo());
+    mgb_assert(!opr->same_type<InputPlaceholder>());
     for (auto &&i : opr->input()) {
-        mgb_assert(i->owner_opr()->dyn_typeinfo() ==
-                InputPlaceholder::typeinfo());
+        mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
     }
     return opr;
 }
@@ -605,17 +605,18 @@ size_t ProxyGraph::get_opr_output_size(const OpDef& opdef,
     return get_proxy_opr(opdef, inputs)->usable_output().size();
 }
 
-SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible(
         const OpDef& opdef,
         const SmallVector<LogicalTensorDesc>& inputs) {
     auto opr = get_proxy_opr(opdef, inputs);
     CUR_OPR_GUARD(opr);
-    do_shape_infer(false);
-    SmallVector<LogicalTensorDesc> ret;
+    SmallVector<LogicalTensorDesc> outputs;
+    bool validated = do_shape_infer(false);
     for (auto&& i : opr->usable_output()) {
-        ret.push_back({{i->shape(), i->dtype()}, i->comp_node()});
+        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
     }
-    return ret;
+    bool need_check = opr->same_type<opr::Reshape>();
+    return {outputs, validated && !need_check};
 }
 
 struct ProxyGraph::GradGraph {
@@ -811,16 +812,20 @@ VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTenso
 /*********************** Common Impl ***********************/
 
-void ProxyGraph::do_shape_infer(bool sync_value) {
+bool ProxyGraph::do_shape_infer(bool sync_value) {
     m_static_infer_manager->update();
+    bool validated = true;
     for (auto* var : m_cur_opr->output()) {
         if (sync_value) {
             var->shape(m_static_infer_manager->infer_shape(var));
         } else if (auto* shape = m_static_infer_manager->infer_shape_fallible(var)) {
             var->shape(*shape);
+        } else {
+            validated = false;
         }
     }
+    return validated;
 }
 
 TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
......
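`do_shape_infer` now reports whether static inference produced a shape for every output, and `infer_output_attrs_fallible` deliberately clears the flag for `opr::Reshape` via `need_check` (the FIXME in `ChannelImpl::apply_op` notes the op-level check is not yet trustworthy). Reassembled from the `+` lines above, the helper now reads:

```cpp
bool ProxyGraph::do_shape_infer(bool sync_value) {
    m_static_infer_manager->update();
    bool validated = true;
    for (auto* var : m_cur_opr->output()) {
        if (sync_value) {
            var->shape(m_static_infer_manager->infer_shape(var));
        } else if (auto* shape =
                           m_static_infer_manager->infer_shape_fallible(var)) {
            var->shape(*shape);
        } else {
            // shape unknown at this point: mark inference as not validated
            validated = false;
        }
    }
    return validated;
}
```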
@@ -48,7 +48,7 @@ public:
             const OpDef& opdef,
             const SmallVector<LogicalTensorDesc>& inputs);
 
-    SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             const OpDef& opdef,
             const SmallVector<LogicalTensorDesc>& inputs);
@@ -88,7 +88,7 @@ private:
     /********************** Common Helper **********************/
 
-    void do_shape_infer(bool sync_value);
+    bool do_shape_infer(bool sync_value);
 
     TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share=true);
......
@@ -80,8 +80,7 @@ apply_on_physical_tensor(const OpDef& def,
     return outputs;
 }
 
-SmallVector<LogicalTensorDesc>
-infer_output_attrs_fallible(const OpDef& def,
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& graph = ProxyGraph::get_default_graph();
     return graph->infer_output_attrs_fallible(def, inputs);
@@ -136,4 +135,4 @@ make_backward_graph(const OpDef& def,
 } // namespace imperative
 } // namespace mgb
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
@@ -21,8 +21,7 @@ SmallVector<TensorPtr>
 apply_on_physical_tensor(const OpDef& def,
         const SmallVector<TensorPtr>& inputs);
 
-SmallVector<LogicalTensorDesc>
-infer_output_attrs_fallible(const OpDef& def,
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs);
 
 BackwardGraphResult
@@ -35,4 +34,4 @@ make_backward_graph(const OpDef& def,
 } // namespace imperative
 } // namespace mgb
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
@@ -44,7 +44,7 @@ public:
             const OpDef& def,
             const VarNodeArray& inputs);
 
-    static SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+    static std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             const OpDef& def,
             const SmallVector<LogicalTensorDesc>& inputs);
......
@@ -38,8 +38,8 @@ public:
     SmallVector<TensorPtr>
     apply(const SmallVector<TensorPtr>& inputs) const;
 
-    SmallVector<LogicalTensorDesc>
-    infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const;
+    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs(
+            const SmallVector<LogicalTensorDesc>& inputs) const;
 
     template <typename T, typename F, typename C>
     SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
......