Commit e283663a authored by Megvii Engine Team

fix(mge/imperative): update tests to new optimizer api

GitOrigin-RevId: 3d06e3db3c6e057505cfc6df1f2b03c01b6d1470
Parent: b5016b9d
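This commit ports the imperative tests from the old Optimizer-based autodiff calls (opt.zero_grad() / opt.record() / opt.backward()) to the new GradManager API in megengine.autodiff, and simplifies the per-parameter gradient checks inside the optimizers. A minimal sketch of the updated training-step pattern as it appears in the tests below (net and data stand in for any existing Module and input tensor):

    import megengine.autodiff as ad
    import megengine.optimizer as optimizer

    # `net` is assumed to be an existing megengine Module, `data` an input tensor.
    gm = ad.GradManager().register(net.parameters())  # gradients are now recorded by GradManager
    opt = optimizer.SGD(net.parameters(), lr=1.0)     # the optimizer only applies updates

    opt.clear_grad()                                  # replaces opt.zero_grad()
    with gm.record():                                 # replaces opt.record()
        loss = net(data).sum()
        gm.backward(loss)                             # replaces opt.backward(loss)
    opt.step()

The diff below applies exactly this substitution file by file.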
@@ -29,6 +29,7 @@ class GradManager:
     def register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
+        return self

     def backward(self, ys, dys=None):
         global backwarding_grad_manager
...
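The only functional change to GradManager itself in this hunk is that register_after_backward_callback now returns self, so callback registration can be chained in the same fluent style as register() in the tests. A hedged sketch of that usage (the callback signature is not shown in this diff, so a catch-all lambda is used as a placeholder):

    gm = (
        ad.GradManager()
        .register(net.parameters())
        .register_after_backward_callback(lambda *args: None)  # placeholder callback
    )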
@@ -177,6 +177,13 @@ class Grad:
         dys = aslist(dys)
         assert len(ys) == len(dys)
+        ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
+        if len(ids) == 0:
+            return
+        ys = [y for i, y in enumerate(ys) if i in ids]
+        dys = [dy for i, dy in enumerate(dys) if i in ids]
         # ys is changed to a list of VariableNode which contains more information
         # such as OpNode, callback, etc.
         ys = [i._extra_data[self].node for i in ys]
...
@@ -20,8 +20,8 @@ from ..core.autodiff.grad import (
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.core import apply
 from ..core.tensor.tensor import Tensor, tensor_apply
-from ..tensor import tensor
 from ..device import get_default_device
+from ..tensor import tensor
 from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

 __all__ = [
...
@@ -11,7 +11,7 @@ from typing import Iterable, Union
 import numpy as np
 from ..functional import sqrt
-from ..tensor_nn import Buffer, Parameter
+from ..tensor_nn import Parameter
 from .optimizer import Optimizer
@@ -63,16 +63,7 @@ class Adadelta(Optimizer):
         for param in param_group["params"]:
-            if param.__wrapped__ in self._grad_skip:
-                self._grad_skip.remove(param.__wrapped__)
-                continue
-            if not isinstance(param.grad, Buffer):
-                raise TypeError(
-                    "grad must be a Buffer, maybe you forget to call backward()?"
-                )
-            if not param.requires_grad:
+            if not param.requires_grad or "grad" not in param.__dict__:
                 continue
             states = self._state[param]
@@ -91,5 +82,3 @@ class Adadelta(Optimizer):
             acc_delta = rho * acc_delta + (1 - rho) * delta ** 2
             states["square_avg"]._reset(square_avg)
             states["acc_delta"]._reset(acc_delta)
-        assert len(self._grad_skip) == 0
@@ -11,7 +11,7 @@ from typing import Iterable, Union
 import numpy as np
 from ..functional import sqrt
-from ..tensor_nn import Buffer, Parameter
+from ..tensor_nn import Parameter
 from .optimizer import Optimizer
@@ -62,16 +62,7 @@ class Adagrad(Optimizer):
         for param in param_group["params"]:
-            if param.__wrapped__ in self._grad_skip:
-                self._grad_skip.remove(param.__wrapped__)
-                continue
-            if not isinstance(param.grad, Buffer):
-                raise TypeError(
-                    "grad must be a Buffer, maybe you forget to call backward()?"
-                )
-            if not param.requires_grad:
+            if not param.requires_grad or "grad" not in param.__dict__:
                 continue
             states = self._state[param]
@@ -87,4 +78,3 @@ class Adagrad(Optimizer):
             clr = lr / (1 + (step - 1) * lr_decay)
             param -= clr * delta
-        assert len(self._grad_skip) == 0
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Iterable, Tuple, Union
-from ..tensor_nn import Buffer, Parameter
+from ..tensor_nn import Parameter
 from .optimizer import Optimizer
@@ -59,18 +59,9 @@ class Adam(Optimizer):
         for param in param_group["params"]:
-            if param.__wrapped__ in self._grad_skip:
-                self._grad_skip.remove(param.__wrapped__)
+            if not param.requires_grad or "grad" not in param.__dict__:
                 continue
-            if not param.requires_grad:
-                continue
-            if not isinstance(param.grad, Buffer):
-                raise TypeError(
-                    "grad must be a Buffer, maybe you forget to call backward()?"
-                )
             grad = param.grad
             if weight_decay != 0.0:
                 grad += param * weight_decay
@@ -91,5 +82,3 @@ class Adam(Optimizer):
             # not inplace change, need to update underlying tensor handler in state
             states["exp_avg"]._reset(exp_avg)
             states["exp_avg_sq"]._reset(exp_avg_sq)
-        assert len(self._grad_skip) == 0
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Iterable, Union
-from ..tensor_nn import Buffer, Parameter
+from ..tensor_nn import Parameter
 from .optimizer import Optimizer
@@ -52,7 +52,7 @@ class SGD(Optimizer):
         momentum = param_group["momentum"]
         for param in param_group["params"]:
-            if not param.requires_grad:
+            if not param.requires_grad or "grad" not in param.__dict__:
                 continue
             grad = param.grad
...
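All four optimizers (Adadelta, Adagrad, Adam, SGD) now share the same guard: a parameter is skipped when it does not require grad or when no gradient was attached to it, which replaces the old _grad_skip bookkeeping and the "grad must be a Buffer" TypeError. A hedged sketch of the common skip logic, using the names visible in the hunks above (param_group and self._state follow the structure shown there):

    # Per-parameter guard the optimizers now use; the body of the loop is the
    # optimizer-specific update and is elided here.
    for param in param_group["params"]:
        # Skip frozen parameters, and parameters that never received a gradient
        # in this step (e.g. they were not reached by GradManager.backward).
        if not param.requires_grad or "grad" not in param.__dict__:
            continue
        grad = param.grad
        # ... update `param` using `grad` and self._state[param] ...

With GradManager responsible for producing gradients, a missing .grad attribute is treated as "nothing to update" rather than an error.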
@@ -9,6 +9,7 @@
 import numpy as np
 import megengine
+import megengine.autodiff as ad
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.module import Module
@@ -37,8 +38,9 @@ class Simple2(Module):
 def test_advance_indexing():
     net = Simple()
+    gm = ad.GradManager().register(net.parameters())
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    optim.clear_grad()
     dshape = (10, 10)
     raw_data = np.arange(100).reshape(dshape).astype(np.float32)
@@ -46,9 +48,9 @@ def test_advance_indexing():
     data = tensor(raw_data)
     mask = tensor(raw_mask)
     answer = 1.0 - raw_data[raw_mask].sum()
-    with optim.record():
+    with gm.record():
         loss = net(data, mask).sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(net.a.numpy(), np.array([answer]).astype(np.float32))
@@ -56,15 +58,16 @@ def test_advance_indexing():
 def test_advance_indexing_with_subtensor():
     net = Simple2()
+    gm = ad.GradManager().register(net.parameters())
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    optim.clear_grad()
     dshape = (2, 3, 4, 3, 4, 2)
     raw_data = np.arange(576).reshape(dshape).astype(np.float32)
     data = tensor(raw_data)
     answer = 1.0 - raw_data[1, ..., :, 0:4:2, 0:2].sum()
-    with optim.record():
+    with gm.record():
         loss = net(data).sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(net.a.numpy(), np.array([answer]).astype(np.float32))
@@ -9,6 +9,7 @@
 import numpy as np
 import megengine
+import megengine.autodiff as ad
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.module import Module
@@ -27,14 +28,15 @@ class Simple(Module):
 def test_ai():
     net = Simple()
+    gm = ad.GradManager().register(net.parameters())
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    optim.clear_grad()
     dshape = (10, 10)
     data = tensor(np.ones(dshape).astype(np.float32))
-    with optim.record():
+    with gm.record():
         loss = net(data).sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(
         net.a.numpy(), np.array([1.0 - dshape[0]]).astype(np.float32)
...
@@ -10,6 +10,7 @@ import numpy as np
 import pytest
 import megengine
+import megengine.autodiff as ad
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.module import BatchNorm2d
@@ -24,13 +25,14 @@ def test_frozen_bn():
     saved_wt = m.weight.numpy()
     saved_bias = m.bias.numpy()
+    gm = ad.GradManager().register(m.parameters())
     optim = optimizer.SGD(m.parameters(), lr=1.0)
-    optim.zero_grad()
+    optim.clear_grad()
     data = np.random.random((6, nchannel, 2, 2)).astype("float32")
-    with optim.record():
+    with gm.record():
         loss = m(data).mean()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_equal(m.running_var.numpy(), saved_var)
@@ -44,13 +46,14 @@ def test_bn_no_track_stat():
     nchannel = 3
     m = BatchNorm2d(nchannel, track_running_stats=False)
+    gm = ad.GradManager().register(m.parameters())
     optim = optimizer.SGD(m.parameters(), lr=1.0)
-    optim.zero_grad()
+    optim.clear_grad()
     data = np.random.random((6, nchannel, 2, 2)).astype("float32")
-    with optim.record():
+    with gm.record():
         loss = m(data).sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
@@ -65,13 +68,14 @@ def test_bn_no_track_stat2():
     saved_mean = m.running_mean.numpy()
     assert saved_mean is not None
+    gm = ad.GradManager().register(m.parameters())
     optim = optimizer.SGD(m.parameters(), lr=1.0)
-    optim.zero_grad()
+    optim.clear_grad()
     data = np.random.random((6, nchannel, 2, 2)).astype("float32")
-    with optim.record():
+    with gm.record():
         loss = m(data).sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_equal(m.running_var.numpy(), saved_var)
...
@@ -12,6 +12,7 @@ import numpy as np
 import pytest
 import megengine as mge
+import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Tensor
 from megengine.module import Linear, Module
@@ -76,12 +77,13 @@ def test_training_converge():
     opt = SGD(
         net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4
     )
+    gm = ad.GradManager().register(net.parameters())

     def train(data, label):
-        with opt.record():
+        with gm.record():
             pred = net(data)
             loss = F.cross_entropy_with_softmax(pred, label)
-            opt.backward(loss)
+            gm.backward(loss)
         return loss

     def infer(data):
@@ -93,7 +95,7 @@ def test_training_converge():
     for data, label in itertools.islice(train_dataset, 2000):
         data = Tensor(data, dtype=np.float32)
         label = Tensor(label, dtype=np.int32)
-        opt.zero_grad()
+        opt.clear_grad()
         loss = train(data, label)
         opt.step()
         losses.append(loss.numpy())
...
@@ -15,6 +15,7 @@ import numpy as np
 import pytest
 import megengine as mge
+import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import jit
 from megengine.core._trace_option import set_tensor_shape
@@ -89,11 +90,11 @@ class MnistNet(Module):
         return x

-def train(data, label, net, opt):
-    with opt.record():
+def train(data, label, net, opt, gm):
+    with gm.record():
         pred = net(data)
         loss = F.cross_entropy_with_softmax(pred, label)
-        opt.backward(loss)
+        gm.backward(loss)
     return loss
@@ -116,12 +117,13 @@ def update_model(model_path):
     net.load_state_dict(checkpoint["net_init"])
     lr = checkpoint["sgd_lr"]
     opt = SGD(net.parameters(), lr=lr)
+    gm = ad.GradManager().register(net.parameters())

     data = Tensor(checkpoint["data"], dtype=np.float32)
     label = Tensor(checkpoint["label"], dtype=np.int32)
-    opt.zero_grad()
-    loss = train(data, label, net=net, opt=opt)
+    opt.clear_grad()
+    loss = train(data, label, net, opt, gm)
     opt.step()
     xpu_name = get_xpu_name()
@@ -150,6 +152,7 @@ def run_train(
     net.load_state_dict(checkpoint["net_init"])
     lr = checkpoint["sgd_lr"]
     opt = SGD(net.parameters(), lr=lr)
+    gm = ad.GradManager().register(net.parameters())

     data = Tensor(checkpoint["data"], dtype=np.float32)
     label = Tensor(checkpoint["label"], dtype=np.int32)
@@ -165,8 +168,8 @@ def run_train(
         sublinear_memory_config=sublinear_memory_config,
     )
-    opt.zero_grad()
-    loss = train_func(data, label, net=net, opt=opt)
+    opt.clear_grad()
+    loss = train_func(data, label, net, opt, gm)
     opt.step()
     assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err)
...
@@ -9,6 +9,7 @@
 import numpy as np
 import megengine
+import megengine.autodiff as ad
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.module import Module
@@ -30,13 +31,14 @@ def test_detach():
     net = Simple()

     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    optim.clear_grad()
+    gm = ad.GradManager().register(net.parameters())

     dshape = (10, 10)
     data = tensor(np.ones(dshape).astype(np.float32))
-    with optim.record():
+    with gm.record():
         loss = net(data).sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_equal(net.a.numpy(), np.array([1.0]).astype(np.float32))
     np.testing.assert_equal(
...
@@ -18,6 +18,7 @@ import numpy as np
 import pytest
 import megengine as mge
+import megengine.autodiff as ad
 import megengine.distributed as dist
 import megengine.functional as F
 from megengine.device import get_default_device, set_default_device
@@ -94,11 +95,13 @@ class MnistNet(Module):
         return x

-def train(data, label, net, opt):
-    with opt.record():
+def train(data, label, net, opt, gm):
+    opt.clear_grad()
+    with gm.record():
         pred = net(data)
         loss = F.cross_entropy_with_softmax(pred, label)
-        opt.backward(loss)
+        gm.backward(loss)
+    opt.step()
     return loss
@@ -111,7 +114,7 @@ def update_model(model_path):
     .. code-block:: python

-        from test_correctness import update_model
+        from test_dp_correctness import update_model
         update_model('mnist_model_with_test.mge') # for gpu
         update_model('mnist_model_with_test_cpu.mge') # for cpu
@@ -122,6 +125,11 @@ def update_model(model_path):
     lr = checkpoint["sgd_lr"]
     opt = SGD(net.parameters(), lr=lr)
+    gm = ad.GradManager()
+    gm.register(
+        net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
+    )

     data = Tensor(checkpoint["data"], dtype=np.float32)
     label = Tensor(checkpoint["label"], dtype=np.int32)
@@ -158,24 +166,23 @@ def run_test(
     def worker(rank, max_err):
         dist.init_process_group("localhost", port, p_num, rank, rank)
-        set_default_device(device="gpu{}".format(dist.get_rank()))
         net = MnistNet(has_bn=True)
         net.load_state_dict(checkpoint["net_init"])
         lr = checkpoint["sgd_lr"]
-        opt = SGD(net.parameters(), reduce_method="mean", lr=lr)
+        opt = SGD(net.parameters(), lr=lr)
+        gm = ad.GradManager()
+        gm.register(
+            net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
+        )

         # use same data and label for all gpu's
         # such that the result does not depend on number of gpu
         data_train = Tensor(data)
         label_train = Tensor(label)
-        train_func = train
-        opt.zero_grad()
-        loss = train_func(data_train, label_train, net=net, opt=opt)
-        opt.step()
-        print("{} loss {}".format(get_default_device(), loss.numpy()[0]))
+        loss = train(data_train, label_train, net, opt, gm)
         assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err)
         if dist.get_rank():
...
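In the data-parallel test above, gradient averaging moves out of the optimizer (the old SGD(..., reduce_method="mean")) and into a GradManager callback: each worker registers an all-reduce callback so gradients are averaged across processes during backward, and the updated train() now clears grads, records, runs backward, and steps in one call. A condensed, hedged sketch of the per-worker setup taken from the hunk; names such as net, lr, data, label and train() come from the surrounding test file:

    import megengine.autodiff as ad
    import megengine.distributed as dist
    from megengine.optimizer import SGD

    # The optimizer no longer performs the mean-reduction itself; a GradManager
    # callback all-reduces (averages) gradients over all workers instead.
    opt = SGD(net.parameters(), lr=lr)
    gm = ad.GradManager()
    gm.register(
        net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
    )
    loss = train(data, label, net, opt, gm)  # train() clears grads, records, backwards, steps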
@@ -12,6 +12,7 @@ import numpy as np
 import pytest
 import megengine
+import megengine.autodiff as ad
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.module import Module
@@ -31,12 +32,13 @@ def test_hello_world():
     net = Simple()

     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    optim.clear_grad()
+    gm = ad.GradManager().register(net.parameters())

     data = tensor([2.34])
-    with optim.record():
+    with gm.record():
         loss = net(data)
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(
         net.a.numpy(), np.array([1.23 - 2.34]).astype(np.float32)
...
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Parameter, optimizer
 from megengine.jit import trace
@@ -43,6 +44,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     net = Simple()
     opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
     check_func = check_class(net, **test_case)
+    gm = ad.GradManager().register(net.parameters())

     step = 0
     data_shape = (2, 28)
@@ -54,11 +56,11 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
             check_func.lr += 0.01
         data = tensor(np.random.random(data_shape).astype(np.float32))
-        opt.zero_grad()
-        with opt.record():
+        opt.clear_grad()
+        with gm.record():
             pred = net(data)
             loss = pred.sum()
-            opt.backward(loss)
+            gm.backward(loss)

         ori_params = TensorDict()
         for param in net.parameters():
...
 import numpy as np
 import megengine as mge
+import megengine.autodiff as ad
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.core.tensor.raw_tensor import RawTensor
@@ -21,13 +22,14 @@ def test_save_load():
     net = Simple()

     optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
-    optim.zero_grad()
+    optim.clear_grad()
+    gm = ad.GradManager().register(net.parameters())

     data = tensor([2.34])
-    with optim.record():
+    with gm.record():
         loss = net(data)
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
@@ -53,9 +55,9 @@ def test_save_load():
     optim.load_state_dict(checkpoint["opt_state"])
     print("load done")
-    with optim.record():
+    with gm.record():
         loss = net([1.23])
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     # Restore device
...
@@ -9,6 +9,7 @@
 import numpy as np
 import megengine
+import megengine.autodiff as ad
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.jit import trace
@@ -29,14 +30,15 @@ def test_sgd_momentum():
     net = Simple()

     optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
-    optim.zero_grad()
+    optim.clear_grad()
+    gm = ad.GradManager().register(net.parameters())

     data = tensor([2.34])
     # do a step of train
-    with optim.record():
+    with gm.record():
         loss = net(data)
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)
@@ -48,10 +50,10 @@ def test_sgd_momentum():
     np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)
     # do a step of train
-    optim.zero_grad()
-    with optim.record():
+    optim.clear_grad()
+    with gm.record():
         loss = net(data)
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
...
@@ -9,6 +9,7 @@ import copy
 import numpy as np
+import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.optimizer as optimizer
 from megengine import Parameter
@@ -41,13 +42,14 @@ def test_single_input():
             return x

     net = Simple(av)
-    optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    opt = optimizer.SGD(net.parameters(), lr=1.0)
+    opt.clear_grad()
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
-    optim.step()
+        gm.backward(loss.sum())
+    opt.step()
     np.testing.assert_almost_equal(loss.numpy(), (av * 10))
     np.testing.assert_almost_equal(net.a.numpy(), (av - 10))
@@ -79,13 +81,14 @@ def test_multi_input():
             return x

     net = Simple(av, bv)
-    optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    opt = optimizer.SGD(net.parameters(), lr=1.0)
+    opt.clear_grad()
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
-    optim.step()
+        gm.backward(loss.sum())
+    opt.step()
     np.testing.assert_almost_equal(loss.numpy(), (av * bv))
     np.testing.assert_almost_equal(net.a.numpy(), (av - 2 * bv))
@@ -118,13 +121,14 @@ def test_multi_output():
             return x + y

     net = Simple(av, bv)
-    optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    opt = optimizer.SGD(net.parameters(), lr=1.0)
+    opt.clear_grad()
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
-    optim.step()
+        gm.backward(loss.sum())
+    opt.step()
     np.testing.assert_almost_equal(loss.numpy(), (av * bv + av + bv), decimal=6)
     np.testing.assert_almost_equal(net.a.numpy(), (av - bv - 1), decimal=6)
...