Commit f3863810 authored by Megvii Engine Team

fix(imperative): fix inplace operation of optim

GitOrigin-RevId: 2aaa71eb66c1096d117ed70d2cadae3f85e32ab6
Parent 9330929f
@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param[...] = Tensor(param.numpy(), no_cache=True)
+            param[...] = Tensor(param, no_cache=True)

         for name, default in self._defaults.items():
             if default is required and name not in param_group:
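The one-line change above is the entire fix: param.numpy() round-trips each parameter through a plain NumPy array, so the tensor rebuilt from it loses tensor-level metadata such as its memory format, while constructing the new tensor directly from the existing one keeps that metadata. A minimal sketch of the difference, using only APIs that appear in this commit (the exact format value of an untagged tensor is an assumption):

import numpy as np
import megengine as mge
from megengine import Tensor, amp

x = mge.tensor(np.ones((1, 3, 4, 4)), dtype="float32")
x = amp.convert_tensor_format(x)      # tag the tensor as nhwc
assert x.format == "nhwc"

# Old behavior: .numpy() returns a bare ndarray, so the rebuilt tensor
# no longer carries the nhwc format tag (assumption: the untagged
# format compares unequal to "nhwc").
lost = Tensor(x.numpy(), no_cache=True)
assert lost.format != "nhwc"

# Fixed behavior: constructing from the tensor itself preserves the tag.
kept = Tensor(x, no_cache=True)
assert kept.format == "nhwc"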
 import numpy as np
 import pytest

+import megengine as mge
+import megengine.autodiff as autodiff
 import megengine.functional as F
 import megengine.module as M
+import megengine.optimizer as optim
 from megengine import Parameter, Tensor, amp
 from megengine.core._config import set_auto_format_convert
 from megengine.core._trace_option import use_symbolic_shape
@@ -57,3 +60,42 @@ def test_convert_module(is_inplace):
             )
         else:
             assert param.shape == expected_shape[name], name
+
+
+class Module(M.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn = M.BatchNorm2d(16)
+
+    def forward(self, x):
+        out = F.relu(self.bn(self.conv(x)))
+        return out
+
+
+def test_format_remained():
+    m = Module()
+    m = amp.convert_module_format(m)
+    gm = autodiff.GradManager().attach(m.parameters())
+    opt = optim.SGD(m.parameters(), lr=0.01)
+    scaler = amp.GradScaler()
+
+    image = mge.tensor(np.random.normal(size=(1, 3, 224, 224)), dtype="float32")
+    label = mge.tensor(np.ones((1, 224, 224)), dtype="int32")
+    image = amp.convert_tensor_format(image)
+
+    @amp.autocast(enabled=True)
+    def train_step(image):
+        with gm:
+            logits = m(image)
+            loss = F.nn.cross_entropy(logits, label)
+            scaler.backward(gm, loss)
+        opt.step().clear_grad()
+        return logits
+
+    for _ in range(5):
+        res = train_step(image)
+        assert res.format == "nhwc"
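The new test exercises exactly this path: amp.convert_module_format(m) tags every parameter as nhwc, and constructing optim.SGD(m.parameters(), lr=0.01) rebuilds each parameter through the line changed above. Before the fix, that rebuild went through param.numpy() and dropped the tag, so the assertion that the network output stays "nhwc" across five full forward/backward/step iterations would fail; with the fix it holds.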