From e82fa4ec23b009f7a591cb932cc50473189d33cb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 29 Jan 2021 17:33:39 +0800 Subject: [PATCH] fix(gopt): using new_inp for build_chain in DelayBroadcast pass GitOrigin-RevId: efc63771976a35647508b095015cb35a1e7f0c21 --- imperative/python/megengine/functional/nn.py | 6 +-- imperative/python/megengine/tensor.py | 5 +++ imperative/python/test/integration/test_bn.py | 43 ++++++++++++++++++- src/gopt/impl/misc.cpp | 10 ++++- src/gopt/test/misc.cpp | 26 +++++++++++ 5 files changed, 83 insertions(+), 7 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 9c8886ddf..c629bc70d 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -784,10 +784,10 @@ def sync_batch_norm( if is_distributed(): # reduce all nodes' data to calculate mean and variance - reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim) - stat = concat( - [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1 + reduce_size = broadcast_to( + Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim ) + stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1) stat = all_reduce_sum(stat, group) reduce_size = stat[:, :1].reshape(1) channel_x1s = stat[:, 1 : 1 + _channels] diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 43b0c5bae..3c7f909a2 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -18,6 +18,7 @@ from .core._wrap import device as as_device from .core.ops.builtin import Copy, GetVarShape from .core.tensor.array_method import ArrayMethodMixin from .device import _valid_device, get_default_device +from .logger import get_logger from .utils.deprecation import deprecated @@ -41,6 +42,10 @@ class Tensor(_Tensor, ArrayMethodMixin): cn = device._cn if isinstance(data, _Tensor): + if dtype is not None: + get_logger().warning( + "dtype does not work when creating a new Tensor with another Tensor" + ) obj = _Tensor.__new__(cls, data) else: if isinstance(data, np.ndarray): diff --git a/imperative/python/test/integration/test_bn.py b/imperative/python/test/integration/test_bn.py index 6d351408d..38816f2bd 100644 --- a/imperative/python/test/integration/test_bn.py +++ b/imperative/python/test/integration/test_bn.py @@ -17,7 +17,7 @@ import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.distributed.helper import get_device_count_by_fork from megengine.jit import trace -from megengine.module import BatchNorm2d, Module, SyncBatchNorm +from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False): @@ -68,7 +68,7 @@ def test_frozen_bn(): run_frozen_bn(BatchNorm2d, True, True) -@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") +@pytest.mark.require_ngpu(2) @pytest.mark.isolated_distributed def test_frozen_synced_bn(): @dist.launcher(n_gpus=2) @@ -151,6 +151,45 @@ def test_trace_bn_forward_twice(): np.testing.assert_equal(y.numpy(), 0) +def run_syncbn(trace_mode): + x = F.ones([2, 16, 4, 4], dtype="float32") + + net = Sequential( + Conv2d(16, 16, 1), SyncBatchNorm(16), Conv2d(16, 16, 1), SyncBatchNorm(16), + ) + + gm = ad.GradManager().attach( + net.parameters(), callbacks=dist.make_allreduce_cb("MEAN") + ) + opt = optimizer.SGD(net.parameters(), 1e-3) + + def train_func(x): + with gm: + y = net(x) + loss = y.mean() + gm.backward(loss) + opt.step().clear_grad() + return loss + + if trace_mode is not None: + train_func = trace(train_func, symbolic=trace_mode) + + for _ in range(3): + loss = train_func(x) + loss.numpy() + + +@pytest.mark.require_ngpu(2) +@pytest.mark.isolated_distributed +@pytest.mark.parametrize("trace_mode", [None, True, False]) +def test_trace_several_syncbn(trace_mode): + @dist.launcher(n_gpus=2) + def worker(): + run_syncbn(trace_mode) + + worker() + + # https://github.com/MegEngine/MegEngine/issues/145 def test_frozen_bn_no_affine(): nchannel = 3 diff --git a/src/gopt/impl/misc.cpp b/src/gopt/impl/misc.cpp index c84b5c450..a98ec972d 100644 --- a/src/gopt/impl/misc.cpp +++ b/src/gopt/impl/misc.cpp @@ -226,8 +226,14 @@ void DelayBroadcastPass::apply(OptState& opt) const { if (!prev) prev = rewriter.get_var(opr->input(inp_idx)); if (!opr->same_type()) { - VarNodeArray new_inp = opr->input(); - new_inp.at(inp_idx) = prev; + VarNodeArray new_inp(opr->input().size()); + for (size_t i = 0; i < opr->input().size(); i++) { + if (i == inp_idx) { + new_inp[i] = prev; + } else { + new_inp[i] = rewriter.get_var(opr->input(i)); + } + } opt.call_with_opr(opr, [&] { // create new opr with the original opr's properties auto new_opr = serialization::copy_opr_shallow( diff --git a/src/gopt/test/misc.cpp b/src/gopt/test/misc.cpp index ac38e9382..c85723b8b 100644 --- a/src/gopt/test/misc.cpp +++ b/src/gopt/test/misc.cpp @@ -177,6 +177,32 @@ TEST_PASS(DelayBroadcastPass, LongChain) { ASSERT_EQ(bcast(bcast(relu(relu(x)), y), z), out); } +TEST_PASS(DelayBroadcastPass, ElemwiseChain) { + auto typecvt = [](SymbolVar x) { + return opr::TypeCvt::make(x, dtype::Int32()); + }; + + auto reduce = [](SymbolVar x) { + SymbolVar tshp = x.make_scalar(1); + opr::Reduce::Param param_default{opr::Reduce::Mode::SUM, INT_MAX, + opr::Reduce::Param::DataType::DEFAULT}; + return opr::Reduce::make(x, param_default, tshp); + }; + + auto shp = TensorShape{2, 2}; + + auto x = mkvar("x", {1, 1}); + auto val = x.make_scalar(3); + + auto out = reduce(typecvt(x.broadcast(shp))) + val.broadcast(shp); + out = gopt::GraphOptimizer{}. + add_pass(). + apply({{out}}).endpoint_vars()[0]; + + auto expected = (reduce(typecvt(x).broadcast(shp)) + val).broadcast(shp); + ASSERT_EQ(out, expected); +} + TEST_PASS(ExpandVirtualGradPass, Simple) { auto x = mkvar("x"); check(x * 2, -- GitLab