提交 a78c1109 编写于 作者: M Megvii Engine Team

fix(imperative): add param(axis) for GetVarShape

GitOrigin-RevId: 0b8f821929dc2ad640ac8c5d0a6c13bad519a952
上级 cde9727a
......@@ -20,22 +20,30 @@ namespace {
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
def.cast_final_safe<GetVarShape>();
return opr::GetVarShape::make(inputs).node()->owner_opr();
auto&& op_def = def.cast_final_safe<GetVarShape>();
return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
def.cast_final_safe<GetVarShape>();
auto&& op_def = def.cast_final_safe<GetVarShape>();
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
auto&& inp = inputs[0];
auto&& shp = inp->layout();
mgb_assert(shp.ndim != 0, "input shape invalid");
HostTensorND hv(inp->comp_node(), {shp.ndim}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
for (size_t i = 0; i < shp.ndim; ++i) {
ptr[i] = shp.shape[i];
HostTensorND hv;
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
for (size_t i = 0; i < shp.ndim; ++i) {
ptr[i] = shp.shape[i];
}
}else{
mgb_assert(op_def.axis < shp.ndim);
hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = shp.shape[op_def.axis];
}
return {Tensor::make(std::move(hv))};
}
......@@ -43,29 +51,31 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
def.cast_final_safe<GetVarShape>();
auto&& op_def = def.cast_final_safe<GetVarShape>();
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
auto&& desc = inputs[0];
if (!desc.layout.ndim) {
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
}
DeviceTensorND value(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
auto* ptr = value.ptr<dt_int32>();
for (size_t i = 0; i < desc.layout.ndim; ++i) {
ptr[i] = desc.layout[i];
DeviceTensorND value;
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
auto* ptr = value.ptr<dt_int32>();
for (size_t i = 0; i < desc.layout.ndim; ++i) {
ptr[i] = desc.layout[i];
}
}else{
mgb_assert(op_def.axis < desc.layout.ndim);
value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
auto* ptr = value.ptr<dt_int32>();
ptr[0] = desc.layout[op_def.axis];
}
return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::GetVarShape>();
if (node->config().comp_node().size() ||
node->config().output_dtype().valid() ||
node->param().axis != opr::GetVarShape::Param::INVALID_AXIS) {
mgb_log_debug("weird GetVarShape");
return OpTrait::find_by_typeinfo(OprAttr::typeinfo())->make_from_op_node(node);
}
return GetVarShape::make();
return GetVarShape::make(node->param());
}
OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
......
......@@ -122,7 +122,7 @@ def Eye: MgbHashableOp<"Eye", [EyeParam]> {
);
}
def GetVarShape : MgbHashableOp<"GetVarShape">;
def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;
def Concat: MgbHashableOp<"Concat", [AxisParam]> {
let extraArguments = (ins
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册