diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp
index 7123e13c62425d382158b7f669d27fa473cd553c..98d1c7b26b5382230c160736fc8686804dc24d22 100644
--- a/imperative/src/impl/ops/tensor_manip.cpp
+++ b/imperative/src/impl/ops/tensor_manip.cpp
@@ -20,22 +20,30 @@ namespace {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    def.cast_final_safe<GetVarShape>();
-    return opr::GetVarShape::make(inputs).node()->owner_opr();
+    auto&& op_def = def.cast_final_safe<GetVarShape>();
+    return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
 }
 
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def,
         const SmallVector<TensorPtr>& inputs) {
-    def.cast_final_safe<GetVarShape>();
+    auto&& op_def = def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& inp = inputs[0];
     auto&& shp = inp->layout();
     mgb_assert(shp.ndim != 0, "input shape invalid");
-    HostTensorND hv(inp->comp_node(), {shp.ndim}, dtype::Int32());
-    auto* ptr = hv.ptr<dt_int32>();
-    for (size_t i = 0; i < shp.ndim; ++i) {
-        ptr[i] = shp.shape[i];
+    HostTensorND hv;
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
+        hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32());
+        auto* ptr = hv.ptr<dt_int32>();
+        for (size_t i = 0; i < shp.ndim; ++i) {
+            ptr[i] = shp.shape[i];
+        }
+    }else{
+        mgb_assert(op_def.axis < shp.ndim);
+        hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
+        auto* ptr = hv.ptr<dt_int32>();
+        ptr[0] = shp.shape[op_def.axis];
     }
     return {Tensor::make(std::move(hv))};
 }
@@ -43,29 +51,31 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
-    def.cast_final_safe<GetVarShape>();
+    auto&& op_def = def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
     if (!desc.layout.ndim) {
         return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
     }
-    DeviceTensorND value(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
-    auto* ptr = value.ptr<dt_int32>();
-    for (size_t i = 0; i < desc.layout.ndim; ++i) {
-        ptr[i] = desc.layout[i];
+    DeviceTensorND value;
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
+        value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
+        auto* ptr = value.ptr<dt_int32>();
+        for (size_t i = 0; i < desc.layout.ndim; ++i) {
+            ptr[i] = desc.layout[i];
+        }
+    }else{
+        mgb_assert(op_def.axis < desc.layout.ndim);
+        value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
+        auto* ptr = value.ptr<dt_int32>();
+        ptr[0] = desc.layout[op_def.axis];
     }
     return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
 }
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     auto* node = &node_->cast_final_safe<opr::GetVarShape>();
-    if (node->config().comp_node().size() ||
-        node->config().output_dtype().valid() ||
-        node->param().axis != opr::GetVarShape::Param::INVALID_AXIS) {
-        mgb_log_debug("weird GetVarShape");
-        return OpTrait::find_by_typeinfo(OprAttr::typeinfo())->make_from_op_node(node);
-    }
-    return GetVarShape::make();
+    return GetVarShape::make(node->param());
 }
 
 OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index c96d3cda877909d21b6dca65f21528432be73f47..9d8a54d0c133c47e0ffa115ae19bc96b9d8f775a 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -122,7 +122,7 @@ def Eye: MgbHashableOp<"Eye", [EyeParam]> {
   );
 }
 
-def GetVarShape : MgbHashableOp<"GetVarShape">;
+def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;
 
 def Concat: MgbHashableOp<"Concat", [AxisParam]> {
   let extraArguments = (ins