Commit 8c47c1f1 authored by Megvii Engine Team

perf(syncbn): reimplement with subgraph

GitOrigin-RevId: 13e7e3d3c0d0e9cd8939ad5ddf62bc91a5dabde0
Parent 53da5c79
......@@ -13,6 +13,7 @@ import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._wrap import as_device
from ..ops import builtin
from ..ops.special import Const
......@@ -219,3 +220,49 @@ def _normalize_axis(
)
return axis
raise
def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
if device.physical_name.startswith("cpu"):
gopt_level = None # disable jit and compile
binary_ops = {
"+": builtin.Elemwise(mode="add"),
"-": builtin.Elemwise(mode="sub"),
"*": builtin.Elemwise(mode="mul"),
"/": builtin.Elemwise(mode="true_div"),
"//": builtin.Elemwise(mode="floor_div"),
"**": builtin.Elemwise(mode="pow"),
"√": builtin.Elemwise(mode="expm1"),
"max": builtin.Elemwise(mode="max"),
"additive": builtin.Elemwise(mode="add"),
}
unary_ops = {
"-": builtin.Elemwise(mode="negate"),
}
def decorator(func):
builder = _SubgraphBuilder(name)
def apply_expr(op, *args):
if isinstance(op, str):
if len(args) == 2:
op = binary_ops[op]
elif len(args) == 1:
op = unary_ops[op]
return builder.apply(op, args, 1)[0]
def apply_const(value, dtype=dtype, device=device):
return builder.apply_const(value, dtype, device)
inputs = [builder.input() for _ in range(nr_inputs)]
outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
builder.outputs(outputs)
builder.outputs_has_grad(outputs_has_grad)
if gopt_level is None:
return builder.get()
else:
return builder.compile(gopt_level)
return decorator
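
For orientation, a minimal usage sketch of the decorator defined above (a hypothetical example, not taken from this commit; dtype and device stand in for a real dtype/CompNode pair):

def make_axpb(dtype, device):
    # Build a fused "a * x + b" op out of elementwise subgraph expressions.
    @subgraph("AxPlusB", dtype, device, nr_inputs=3)
    def ax_plus_b(inputs, f, c):
        a, x, b = inputs
        out = f("+", f("*", a, x), b)   # f() maps the string aliases to Elemwise ops
        return (out,), (True,)          # outputs and their has-grad mask
    # With gopt_level left at None this is the raw SubgraphOp; it can be applied
    # like any builtin op: (result,) = apply(ax_plus_b, a, x, b)
    return ax_plus_b
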
......@@ -7,11 +7,13 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import Optional, Sequence, Tuple, Union
from functools import lru_cache
from typing import NamedTuple, Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
......@@ -20,10 +22,13 @@ from ..core.tensor.utils import (
astype,
cast_tensors,
convert_single_value,
make_shape_tuple,
setscalar,
subgraph,
)
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..jit import exclude_from_trace
from ..random import uniform
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
......@@ -1153,6 +1158,111 @@ def batch_norm(
return inp
@lru_cache(maxsize=None)
def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
# fmt: off
@subgraph("SyncBnStage0", dtype, device, 1)
def syncbn_stage0(inputs, f, c):
input = inputs[0]
reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
input_shape = f(GetVarShape(), input)
input_elems = f(Reduce(mode="product", axis=0), input_shape)
reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
reduce_size = f("//", input_elems, reduce_elems)
channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)
@subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3)
def syncbn_stage1(inputs, f, c):
input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
weight, bias = inputs[5:7]
channel_mean = f("/", channel_x1s, reduce_size)
channel_var =\
f("+", f("/", f("**", channel_x1s, c(2)),
f("-", f("*", reduce_size, reduce_size))),
f("/", channel_x2s, reduce_size))
invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
inv_var_wt = f("*", invsqrt_channel_var, weight)
neg_channel_mean = f("-", channel_mean)
outvar =\
f("+", f("*", input, inv_var_wt),
f("+", f("*", neg_channel_mean, inv_var_wt),
bias))
return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)
@subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3)
def syncbn_stage1_inference(inputs, f, c):
input, channel_mean, channel_var, eps = inputs[0:4]
weight, bias = inputs[4:6]
invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
inv_var_wt = f("*", invsqrt_channel_var, weight)
neg_channel_mean = f("-", channel_mean)
outvar =\
f("+", f("*", input, inv_var_wt),
f("+", f("*", neg_channel_mean, inv_var_wt),
bias))
return (outvar,), (True,)
@subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3)
def syncbn_stage2(inputs, f, c):
running_mean, running_var, momentum = inputs[0:3]
reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
running_mean = f("*", running_mean, momentum)
running_mean =\
f("+", running_mean,
f("*", f("-", c(1), momentum),
channel_mean))
channel_variance_unbiased =\
f("+", f("/", f("**", channel_x1s, c(2)),
f("*", f("-", reduce_size),
f("-", reduce_size, c(1)))),
f("/", channel_x2s,
f("-", reduce_size, c(1))))
running_var = f("*", running_var, momentum)
running_var =\
f("+", running_var,
f("*", f("-", c(1), momentum),
channel_variance_unbiased))
return (running_mean, running_var), (True, True)
@subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3)
def syncbn_concat_stats(inputs, f, c):
reduce_size, channel_x1s, channel_x2s = inputs[0:3]
reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
return (stats,), (True,)
@subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3)
def syncbn_split_stats(inputs, f, c):
stats = inputs[0]
c_1 = c(1, dtype="int32")
channel_x1s_end = c(channels+1, dtype="int32")
def _subtensor(src, axis, begin, end):
items = (axis, (begin is not None), (end is not None), False, False),
args = ()
if begin is not None:
args += begin,
if end is not None:
args += end,
return f(builtin.Subtensor(items=items), src, *args)
reduce_size = _subtensor(stats, 1, None, c_1)
channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
reduce_size = f(builtin.Reshape(), reduce_size, c_1)
return (reduce_size, channel_x1s, channel_x2s), (False, True, True)
# fmt: on
return (
syncbn_stage0,
syncbn_stage1,
syncbn_stage1_inference,
syncbn_stage2,
syncbn_concat_stats,
syncbn_split_stats,
)
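
The stages above only ever exchange per-channel sums: syncbn_stage0 produces reduce_size N, channel_x1s S1 = sum(x) and channel_x2s S2 = sum(x**2); syncbn_stage1 derives the biased variance S2/N - S1**2/N**2 and syncbn_stage2 the unbiased S2/(N-1) - S1**2/(N*(N-1)). That is what lets the distributed path all-reduce a single packed (N, S1, S2) tensor. A small NumPy check of the algebra (annotation only, not part of the commit):

import numpy as np

x = np.random.randn(4096)
S1, S2, N = x.sum(), (x ** 2).sum(), x.size
assert np.isclose(S2 / N - S1 ** 2 / N ** 2, x.var())                       # stage1, biased
assert np.isclose(S2 / (N - 1) - S1 ** 2 / (N * (N - 1)), x.var(ddof=1))    # stage2, unbiased
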
def sync_batch_norm(
inp: Tensor,
running_mean: Tensor,
......@@ -1193,52 +1303,55 @@ def sync_batch_norm(
assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
eps_mode
)
_channels = inp.shape[1]
# TODO: cudnnBn fastpath
_channels = make_shape_tuple(inp.shape)[1]
_ndim = inp.ndim
_device = inp.device
_dtype = inp.dtype
_param_shape = (1, _channels) + (1,) * (_ndim - 2)
_reduce_axis = [0] + [i for i in range(2, _ndim)]
if training:
def _make_full_if_none(x, value):
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=_device)()
(result,) = apply(builtin.Broadcast(), x, reduce_shape)
return result
elif x.ndim == 1:
(result,) = apply(builtin.Reshape(), x, reduce_shape)
return result
return x
(
syncbn_stage0,
syncbn_stage1,
syncbn_stage1_inference,
syncbn_stage2,
syncbn_concat_stats,
syncbn_split_stats,
) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)
reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)
def _sum_on_channel(inp):
return inp.sum(axis=_reduce_axis, keepdims=True)
eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
reduce_size = inp.shape[0]
for i in range(2, _ndim):
reduce_size = reduce_size * inp.shape[i]
channel_x1s = _sum_on_channel(inp)
channel_x2s = _sum_on_channel(inp ** 2)
weight = _make_full_if_none(weight, 1)
bias = _make_full_if_none(bias, 0)
if training:
if is_distributed():
# reduce all nodes' data to calculate mean and variance
reduce_size = broadcast_to(
Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
)
stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
(stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
stat = all_reduce_sum(stat, group)
reduce_size = stat[:, :1].reshape(1)
channel_x1s = stat[:, 1 : 1 + _channels]
channel_x2s = stat[:, 1 + _channels :]
reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)
channel_mean = channel_x1s / reduce_size
channel_variance = (
channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
outvar, channel_mean, *_ = apply(
syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
)
else:
assert running_var is not None and running_mean is not None
channel_variance = running_var.reshape(*_param_shape)
channel_mean = running_mean.reshape(*_param_shape)
invsqrt_channel_variance = (
maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps
) ** -0.5
if weight is not None:
weight = weight.reshape(*_param_shape)
if bias is not None:
bias = bias.reshape(*_param_shape)
channel_mean = running_mean
channel_var = running_var
outvar, *_ = apply(
syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
)
# outvar = output * weight + bias
# where output = inp * invsqrt_channel_variance + (
......@@ -1246,28 +1359,18 @@ def sync_batch_norm(
# )
# Manually expand output for gopt
if weight is not None:
inv_var_wt = invsqrt_channel_variance * weight
neg_channel_mean = -channel_mean
if bias is not None:
outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
else:
outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
else:
outvar = inp * invsqrt_channel_variance + (
-channel_mean * invsqrt_channel_variance
)
if bias is not None:
outvar = outvar + bias
if training and running_var is not None and running_mean is not None:
running_mean *= momentum
running_mean += (1 - momentum) * channel_mean
channel_variance_unbiased = channel_x1s ** 2 / (
-reduce_size * (reduce_size - 1)
) + channel_x2s / (reduce_size - 1)
running_var *= momentum
running_var += (1 - momentum) * channel_variance_unbiased
momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
running_mean[...], running_var[...] = apply(
syncbn_stage2,
running_mean,
running_var,
momentum,
reduce_size,
channel_x1s,
channel_x2s,
channel_mean,
)
return outvar
......
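
A minimal single-process call of the rewritten entry point, for reference (hypothetical example; only inp, running_mean and running_var are visible in the signature excerpt above, so the keyword name training is taken from the function body):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(8, 16, 4, 4).astype("float32"))        # NCHW input
running_mean = mge.tensor(np.zeros((1, 16, 1, 1), dtype="float32"))
running_var = mge.tensor(np.ones((1, 16, 1, 1), dtype="float32"))
y = F.sync_batch_norm(x, running_mean, running_var, training=True)
# In distributed training each worker all-reduces its packed
# (reduce_size, channel_x1s, channel_x2s) statistics before syncbn_stage1;
# syncbn_stage2 then updates running_mean / running_var in place.
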
......@@ -66,7 +66,7 @@ def is_tracing():
@contextlib.contextmanager
def exclude_from_trace():
global skip_tracing
if skip_tracing:
if skip_tracing or (active_trace is None):
yield
return
try:
......
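
The extra guard lets the context manager be entered when no trace is active at all: it then simply yields without touching trace state. A hedged sketch of the calling pattern this enables (illustrative, not from the commit):

from megengine.jit import exclude_from_trace

def eager_only_bookkeeping():
    # Safe to call both inside and outside tracing now that a missing
    # active trace short-circuits the context manager.
    with exclude_from_trace():
        pass  # work here is never recorded into a trace
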
......@@ -58,6 +58,9 @@ void init_common(py::module m) {
.def_property_readonly("logical_name", [](const CompNode& cn) {
return cn.to_string_logical();
})
.def_property_readonly("physical_name", [](const CompNode& cn) {
return cn.to_string();
})
.def_property_readonly("get_mem_status_bytes", [](const CompNode& cn) {
return cn.get_mem_status_bytes();
})
......@@ -70,6 +73,7 @@ void init_common(py::module m) {
cn.to_string_physical().c_str(),
cn.to_string_logical().c_str());
})
.def("__hash__", [](CompNode cn){ return mgb::hash(cn); })
.def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self)
.def_static("_get_device_count", &CompNode::get_device_count,
......
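
These two bindings are what the Python-side changes rely on: physical_name lets the subgraph helper detect CPU devices and skip compilation, and __hash__ makes CompNode usable as a dict / functools.lru_cache key, which _get_sync_bn_ops needs in order to cache ops per device. A small illustrative sketch (access through a tensor's .device attribute is an assumption here):

import megengine as mge

dev = mge.tensor([1.0]).device        # a CompNode
print(dev.logical_name)               # e.g. "xpux:0"
print(dev.physical_name)              # newly exposed resolved name, e.g. "cpu0:0"
per_device_cache = {dev: "cached subgraph ops"}   # hashable thanks to __hash__
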
......@@ -15,6 +15,7 @@
#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
......@@ -477,4 +478,50 @@ void init_ops(py::module m) {
m.def("set_global_rng_seed", &rng::set_global_rng_seed);
m.def("get_global_rng_seed", &rng::get_global_rng_seed);
m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);
struct PySubgraphBuilder {
explicit PySubgraphBuilder(std::string name) : name{name}{}
std::string name;
Subgraph graph;
mgb::SmallVector<bool> output_grad_mask;
Subgraph::var_t next_var = 1;
};
py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
.def(py::init<std::string>())
.def("input", [](PySubgraphBuilder& self){
auto var = self.next_var++;
self.graph.inputs.push_back(var);
return var;
})
.def("apply", [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op, Subgraph::vars_t inputs, size_t nr_outputs){
Subgraph::vars_t outputs;
for (size_t i = 0; i < nr_outputs; ++i) {
outputs.push_back(self.next_var++);
}
self.graph.exprs.push_back({op, inputs, outputs});
return outputs;
})
.def("apply_const", [](PySubgraphBuilder& self, py::object value, mgb::DType dtype, mgb::CompNode cn){
auto var = self.next_var++;
mgb::HostTensorND hvalue(cn);
npy::np2tensor(value.cast<py::array>().ptr(), npy::Meth::copy_into(&hvalue), dtype);
self.graph.constants.push_back({var, Tensor::make(hvalue)});
return var;
})
.def("outputs", [](PySubgraphBuilder& self, Subgraph::vars_t outputs){
self.graph.outputs = outputs;
self.output_grad_mask.resize(outputs.size(), true);
})
.def("outputs_has_grad", [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad){
            mgb_assert(self.graph.outputs.size() == outputs_has_grad.size());
self.output_grad_mask = outputs_has_grad;
})
.def("get", [](PySubgraphBuilder& self){
return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
})
.def("compile", [](PySubgraphBuilder& self, int gopt_level){
auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
});
}
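
The binding mirrors the protocol used by the Python subgraph decorator: declare inputs, apply ops to obtain fresh output vars, register the outputs together with their grad mask, then either get() the raw SubgraphOp or compile() it with a graph-optimization level. A hand-driven sketch from Python (dtype and device are placeholders):

from megengine.core._imperative_rt.ops import SubgraphBuilder
from megengine.core.ops import builtin

def build_add_one(dtype, device):
    builder = SubgraphBuilder("AddOne")
    x = builder.input()                                   # new input var
    one = builder.apply_const(1, dtype, device)           # constant var
    (y,) = builder.apply(builtin.Elemwise(mode="add"), [x, one], 1)
    builder.outputs([y])
    builder.outputs_has_grad([True])
    return builder.get()                                  # or builder.compile(2)
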