提交 39417856 编写于 作者: P panyifeng

fix valuenode simplify

上级 bc0a53cf
...@@ -43,26 +43,28 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) { ...@@ -43,26 +43,28 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
return nullptr; return nullptr;
} }
AbstractBasePtr res = t;
if (t->isa<AbstractClass>()) { if (t->isa<AbstractClass>()) {
auto abs_class = dyn_cast<AbstractClass>(t); auto abs_class = dyn_cast<AbstractClass>(t);
AbstractBasePtrList baselist; AbstractBasePtrList baselist;
auto attributes = abs_class->attributes(); auto attributes = abs_class->attributes();
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
[](const AbstractAttribute &item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist); return std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractDictionary>()) { }
if (t->isa<AbstractDictionary>()) {
auto abs_dict = dyn_cast<AbstractDictionary>(t); auto abs_dict = dyn_cast<AbstractDictionary>(t);
AbstractBasePtrList baselist; AbstractBasePtrList baselist;
auto elements = abs_dict->elements(); auto elements = abs_dict->elements();
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
[](const AbstractAttribute &item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist); return std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractList>()) { }
auto abs_dict = dyn_cast<AbstractList>(t); if (t->isa<AbstractList>()) {
res = std::make_shared<AbstractTuple>(abs_dict->elements()); auto abs_list = dyn_cast<AbstractList>(t);
return std::make_shared<AbstractTuple>(abs_list->elements());
} }
return res;
return nullptr;
} }
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
...@@ -376,7 +378,12 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr ...@@ -376,7 +378,12 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
for (auto &node : manager->all_nodes()) { for (auto &node : manager->all_nodes()) {
auto ret = Reabs(node->abstract()); auto ret = Reabs(node->abstract());
node->set_abstract(ret); if (ret) {
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
<< ret->ToString();
node->set_abstract(ret);
changed = true;
}
} }
return changed; return changed;
} }
......
...@@ -1031,3 +1031,13 @@ def test_grad_if_defer_inline(): ...@@ -1031,3 +1031,13 @@ def test_grad_if_defer_inline():
inp = Tensor(np.ones([128, 96]).astype(np.float32)) inp = Tensor(np.ones([128, 96]).astype(np.float32))
grads = C.grad_all(network)(inp) grads = C.grad_all(network)(inp)
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
def test_dict_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.res = {'1': 10}
def construct(self):
return self.res
Net()()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册