提交 4917534b 编写于 作者: M Megvii Engine Team

feat(mge/optimizer): save state's numpy value by default in `state_dict`

GitOrigin-RevId: ec7e4d56f54f724c039b462906583bf025d060e6
上级 84f990a0
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from contextlib import contextmanager
from typing import Dict from typing import Dict
from typing import Iterable as Iter from typing import Iterable as Iter
from typing import Union from typing import Union
...@@ -180,7 +179,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -180,7 +179,7 @@ class Optimizer(metaclass=ABCMeta):
param.grad = None param.grad = None
pop_scope("clear_grad") pop_scope("clear_grad")
def state_dict(self) -> Dict: def state_dict(self, keep_var=False) -> Dict:
r""" r"""
Export the optimizer state. Export the optimizer state.
...@@ -198,6 +197,9 @@ class Optimizer(metaclass=ABCMeta): ...@@ -198,6 +197,9 @@ class Optimizer(metaclass=ABCMeta):
cur_id += 1 cur_id += 1
for param, st in self._state.items(): for param, st in self._state.items():
if not keep_var:
for k, v in st.items():
st[k] = v.numpy()
state[param2id[param]] = st state[param2id[param]] = st
for group in self.param_groups: for group in self.param_groups:
...@@ -218,7 +220,6 @@ class Optimizer(metaclass=ABCMeta): ...@@ -218,7 +220,6 @@ class Optimizer(metaclass=ABCMeta):
raise ValueError( raise ValueError(
"loaded state dict has a different number of parameter groups" "loaded state dict has a different number of parameter groups"
) )
parameter_map = dict() # type: Dict
for group_new, group_saved in zip(self.param_groups, state["param_groups"]): for group_new, group_saved in zip(self.param_groups, state["param_groups"]):
if len(group_new["params"]) != len(group_saved["params"]): if len(group_new["params"]) != len(group_saved["params"]):
raise ValueError( raise ValueError(
...@@ -232,8 +233,9 @@ class Optimizer(metaclass=ABCMeta): ...@@ -232,8 +233,9 @@ class Optimizer(metaclass=ABCMeta):
self._state[p] = state["state"][param_saved].copy() self._state[p] = state["state"][param_saved].copy()
for k, v in self._state[p].items(): for k, v in self._state[p].items():
if isinstance(v, Tensor): if isinstance(v, Tensor):
# TODO: maybe a more efficient way? self._state[p][k] = v.detach()
self._state[p][k] = Tensor(v.numpy()) else:
self._state[p][k] = Tensor(v)
if set(group_new.keys()) != set(group_saved.keys()): if set(group_new.keys()) != set(group_saved.keys()):
raise ValueError( raise ValueError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册