提交 e82fa4ec 编写于 作者: M Megvii Engine Team

fix(gopt): using new_inp for build_chain in DelayBroadcast pass

GitOrigin-RevId: efc63771976a35647508b095015cb35a1e7f0c21
上级 a09fc5f7
......@@ -784,10 +784,10 @@ def sync_batch_norm(
if is_distributed():
# reduce all nodes' data to calculate mean and variance
reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
stat = concat(
[reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
reduce_size = broadcast_to(
Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
)
stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
stat = all_reduce_sum(stat, group)
reduce_size = stat[:, :1].reshape(1)
channel_x1s = stat[:, 1 : 1 + _channels]
......
......@@ -18,6 +18,7 @@ from .core._wrap import device as as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .logger import get_logger
from .utils.deprecation import deprecated
......@@ -41,6 +42,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
cn = device._cn
if isinstance(data, _Tensor):
if dtype is not None:
get_logger().warning(
"dtype does not work when creating a new Tensor with another Tensor"
)
obj = _Tensor.__new__(cls, data)
else:
if isinstance(data, np.ndarray):
......
......@@ -17,7 +17,7 @@ import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
from megengine.module import BatchNorm2d, Module, SyncBatchNorm
from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
......@@ -68,7 +68,7 @@ def test_frozen_bn():
run_frozen_bn(BatchNorm2d, True, True)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_frozen_synced_bn():
@dist.launcher(n_gpus=2)
......@@ -151,6 +151,45 @@ def test_trace_bn_forward_twice():
np.testing.assert_equal(y.numpy(), 0)
def run_syncbn(trace_mode):
x = F.ones([2, 16, 4, 4], dtype="float32")
net = Sequential(
Conv2d(16, 16, 1), SyncBatchNorm(16), Conv2d(16, 16, 1), SyncBatchNorm(16),
)
gm = ad.GradManager().attach(
net.parameters(), callbacks=dist.make_allreduce_cb("MEAN")
)
opt = optimizer.SGD(net.parameters(), 1e-3)
def train_func(x):
with gm:
y = net(x)
loss = y.mean()
gm.backward(loss)
opt.step().clear_grad()
return loss
if trace_mode is not None:
train_func = trace(train_func, symbolic=trace_mode)
for _ in range(3):
loss = train_func(x)
loss.numpy()
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.parametrize("trace_mode", [None, True, False])
def test_trace_several_syncbn(trace_mode):
@dist.launcher(n_gpus=2)
def worker():
run_syncbn(trace_mode)
worker()
# https://github.com/MegEngine/MegEngine/issues/145
def test_frozen_bn_no_affine():
nchannel = 3
......
......@@ -226,8 +226,14 @@ void DelayBroadcastPass::apply(OptState& opt) const {
if (!prev)
prev = rewriter.get_var(opr->input(inp_idx));
if (!opr->same_type<opr::Broadcast>()) {
VarNodeArray new_inp = opr->input();
new_inp.at(inp_idx) = prev;
VarNodeArray new_inp(opr->input().size());
for (size_t i = 0; i < opr->input().size(); i++) {
if (i == inp_idx) {
new_inp[i] = prev;
} else {
new_inp[i] = rewriter.get_var(opr->input(i));
}
}
opt.call_with_opr(opr, [&] {
// create new opr with the original opr's properties
auto new_opr = serialization::copy_opr_shallow(
......
......@@ -177,6 +177,32 @@ TEST_PASS(DelayBroadcastPass, LongChain) {
ASSERT_EQ(bcast(bcast(relu(relu(x)), y), z), out);
}
TEST_PASS(DelayBroadcastPass, ElemwiseChain) {
auto typecvt = [](SymbolVar x) {
return opr::TypeCvt::make(x, dtype::Int32());
};
auto reduce = [](SymbolVar x) {
SymbolVar tshp = x.make_scalar(1);
opr::Reduce::Param param_default{opr::Reduce::Mode::SUM, INT_MAX,
opr::Reduce::Param::DataType::DEFAULT};
return opr::Reduce::make(x, param_default, tshp);
};
auto shp = TensorShape{2, 2};
auto x = mkvar("x", {1, 1});
auto val = x.make_scalar(3);
auto out = reduce(typecvt(x.broadcast(shp))) + val.broadcast(shp);
out = gopt::GraphOptimizer{}.
add_pass<gopt::DelayBroadcastPass>().
apply({{out}}).endpoint_vars()[0];
auto expected = (reduce(typecvt(x).broadcast(shp)) + val).broadcast(shp);
ASSERT_EQ(out, expected);
}
TEST_PASS(ExpandVirtualGradPass, Simple) {
auto x = mkvar("x");
check(x * 2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册