提交 cad8568c 编写于 作者: M Megvii Engine Team

fix(mge/optimizer): fix optimizer's state_dict bug

GitOrigin-RevId: 67fb112fb8e4d7b295a6f4a2a2c8254002c97bbc
上级 0ed36998
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from typing import Dict from typing import Dict
...@@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta): ...@@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta):
cur_id += 1 cur_id += 1
for param, st in self._state.items(): for param, st in self._state.items():
_st = copy.copy(st)
if not keep_var: if not keep_var:
for k, v in st.items(): for k, v in st.items():
st[k] = v.numpy() _st[k] = v.numpy()
state[param2id[param]] = st state[param2id[param]] = _st
for group in self.param_groups: for group in self.param_groups:
param_group = {k: v for k, v in group.items() if k != "params"} param_group = {k: v for k, v in group.items() if k != "params"}
......
...@@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ...@@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
) )
step += 1 step += 1
check_func(ori_params, net.parameters(), step) check_func(ori_params, net.parameters(), step)
try_state_dict = {
"net": net.state_dict(),
"opt": opt.state_dict(),
}
def test_sgd(): def test_sgd():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册