提交 9c17cfc4 编写于 作者: M Megvii Engine Team

fix(imperative/ops): add check_dtype for Elemwise in infer_attrs

GitOrigin-RevId: c7778557537f94648e1e706af00fe2d917f18fbd
上级 05550bc5
......@@ -65,6 +65,25 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{out_layout, out_cn}}, false};
}
}
// copy from megdnn::ElemwiseForward::check_dtype
switch (out_dt.category()) {
case DTypeCategory::FLOAT:
mgb_assert(trait.allow_float, "unsupport mode %s for float\n",
trait.name);
break;
case DTypeCategory::INT:
mgb_assert(trait.allow_int, "unsupport mode %s for int\n",
trait.name);
break;
case DTypeCategory::BOOL:
mgb_assert(trait.allow_bool, "unsupport mode %s for bool\n",
trait.name);
break;
default:
// Quantized Dtype could also be handled by this op,
// but scales need to be the same.
break;
}
auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册