diff --git a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h index f1f73de4d930f46f84665f01b052e3c7f8b146ed..fb43f6ffd8acd6eaef363789e9ba21591638e1e1 100644 --- a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h @@ -50,11 +50,15 @@ class ReshapeSameShapeEliminater : public AnfVisitor { } auto src_shape = src_shape_abs->GetShapeTrack(); - auto tgt_shape = GetValueNode(shape_); - if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa()) { - auto elements = GetValue>(tgt_shape); + auto tgt_shape_abs = node->abstract(); + if (tgt_shape_abs == nullptr) { + return nullptr; + } + auto tgt_shape = tgt_shape_abs->GetShapeTrack(); + if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa() && tgt_shape->isa()) { + auto elements = tgt_shape->cast(); auto shape = src_shape->cast(); - if (shape->shape() == elements) { + if (shape->shape() == elements->shape()) { return x_; } } diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 8cfbcc9f42d85ffe93d4297406347ec5cb2caff3..804a1f3aa3d6aa34a34ddd9297458b0cbb45562f 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -219,6 +219,7 @@ TEST_F(TestOptLib, test_elim_reshape_same_shape) { tensor::TensorPtr x_tensor = std::make_shared(kFloat32->type_id(), shp); auto x_abstract = x_tensor->ToAbstract(); x_node->set_abstract(x_abstract); + before->output()->set_abstract(x_abstract); } auto patterns = std::vector({irpass.reshape_eliminate_}); ASSERT_TRUE(CheckOpt(before, after, patterns));