提交 518c7f37 编写于 作者: M Megvii Engine Team 提交者: dengzheye

fix(imperative/src): fix empty_tensor bug of rng

GitOrigin-RevId: 4c948f41f04649620ce7b34c5f3dac69d66705e2
上级 cca38c4e
...@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { ...@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
template <typename Op> template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
LogicalTensorDesc dest; LogicalTensorDesc dest;
auto&& xxx_rng_def = def.cast_final_safe<Op>(); auto&& xxx_rng_def = def.cast_final_safe<Op>();
size_t nr_inp = inputs.size(); size_t nr_inp = inputs.size();
...@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
xxx_rng_def.dyn_typeinfo()->name, nr_inp); xxx_rng_def.dyn_typeinfo()->name, nr_inp);
} }
dest.comp_node = inputs[0].comp_node; dest.comp_node = inputs[0].comp_node;
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def); if (success) {
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
} else {
dest.layout = TensorLayout(inputs[0].layout.dtype);
}
return {{dest}, inputs[0].layout.ndim != 0}; return {{dest}, inputs[0].layout.ndim != 0};
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册