提交 518c7f37 编写于 作者: M Megvii Engine Team 提交者: dengzheye

fix(imperative/src): fix empty_tensor bug of rng

GitOrigin-RevId: 4c948f41f04649620ce7b34c5f3dac69d66705e2
上级 cca38c4e
......@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
LogicalTensorDesc dest;
auto&& xxx_rng_def = def.cast_final_safe<Op>();
size_t nr_inp = inputs.size();
......@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
xxx_rng_def.dyn_typeinfo()->name, nr_inp);
}
dest.comp_node = inputs[0].comp_node;
if (success) {
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
} else {
dest.layout = TensorLayout(inputs[0].layout.dtype);
}
return {{dest}, inputs[0].layout.ndim != 0};
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册