From 8c47c1f14942ab473b736c133a3d0353858bacda Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 2 Aug 2021 17:29:50 +0800
Subject: [PATCH] perf(syncbn): reimplement with subgraph

GitOrigin-RevId: 13e7e3d3c0d0e9cd8939ad5ddf62bc91a5dabde0
---
 .../python/megengine/core/tensor/utils.py    |  47 ++++
 imperative/python/megengine/functional/nn.py | 213 +++++++++++++-----
 imperative/python/megengine/jit/tracing.py   |   2 +-
 imperative/python/src/common.cpp             |   4 +
 imperative/python/src/ops.cpp                |  47 ++++
 5 files changed, 257 insertions(+), 56 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 201366a10..816e8a94a 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -13,6 +13,7 @@ import numpy as np
 
 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
+from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from .._wrap import as_device
 from ..ops import builtin
 from ..ops.special import Const
@@ -219,3 +220,49 @@ def _normalize_axis(
         )
         return axis
     raise
+
+
+def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
+    if device.physical_name.startswith("cpu"):
+        gopt_level = None  # disable jit and compile
+
+    binary_ops = {
+        "+": builtin.Elemwise(mode="add"),
+        "-": builtin.Elemwise(mode="sub"),
+        "*": builtin.Elemwise(mode="mul"),
+        "/": builtin.Elemwise(mode="true_div"),
+        "//": builtin.Elemwise(mode="floor_div"),
+        "**": builtin.Elemwise(mode="pow"),
+        "√": builtin.Elemwise(mode="expm1"),
+        "max": builtin.Elemwise(mode="max"),
+        "additive": builtin.Elemwise(mode="add"),
+    }
+
+    unary_ops = {
+        "-": builtin.Elemwise(mode="negate"),
+    }
+
+    def decorator(func):
+        builder = _SubgraphBuilder(name)
+
+        def apply_expr(op, *args):
+            if isinstance(op, str):
+                if len(args) == 2:
+                    op = binary_ops[op]
+                elif len(args) == 1:
+                    op = unary_ops[op]
+            return builder.apply(op, args, 1)[0]
+
+        def apply_const(value, dtype=dtype, device=device):
+            return builder.apply_const(value, dtype, device)
+
+        inputs = [builder.input() for _ in range(nr_inputs)]
+        outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
+        builder.outputs(outputs)
+        builder.outputs_has_grad(outputs_has_grad)
+        if gopt_level is None:
+            return builder.get()
+        else:
+            return builder.compile(gopt_level)
+
+    return decorator
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 72078a7b5..d7c7a64b8 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -7,11 +7,13 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # pylint: disable=too-many-lines
-from typing import Optional, Sequence, Tuple, Union
+from functools import lru_cache
+from typing import NamedTuple, Optional, Sequence, Tuple, Union
 
 from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core.ops import builtin
-from ..core.ops.builtin import BatchNorm, Elemwise
+from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
 from ..core.ops.special import Const
 from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
@@ -20,10 +22,13 @@ from ..core.tensor.utils import (
     astype,
     cast_tensors,
     convert_single_value,
+    make_shape_tuple,
     setscalar,
+    subgraph,
 )
 from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
+from ..jit import exclude_from_trace
 from ..random import uniform
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
@@ -1153,6 +1158,111 @@ def batch_norm(
     return inp
 
 
+@lru_cache(maxsize=None)
+def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
+    # fmt: off
+    @subgraph("SyncBnStage0", dtype, device, 1)
+    def syncbn_stage0(inputs, f, c):
+        input = inputs[0]
+        reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
+        input_shape = f(GetVarShape(), input)
+        input_elems = f(Reduce(mode="product", axis=0), input_shape)
+        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
+        reduce_size = f("//", input_elems, reduce_elems)
+        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
+        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
+        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
+        return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)
+
+    @subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3)
+    def syncbn_stage1(inputs, f, c):
+        input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
+        weight, bias = inputs[5:7]
+        channel_mean = f("/", channel_x1s, reduce_size)
+        channel_var =\
+            f("+", f("/", f("**", channel_x1s, c(2)),
+                          f("-", f("*", reduce_size, reduce_size))),
+                   f("/", channel_x2s, reduce_size))
+        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
+        inv_var_wt = f("*", invsqrt_channel_var, weight)
+        neg_channel_mean = f("-", channel_mean)
+        outvar =\
+            f("+", f("*", input, inv_var_wt),
+                   f("+", f("*", neg_channel_mean, inv_var_wt),
+                          bias))
+        return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)
+
+    @subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3)
+    def syncbn_stage1_inference(inputs, f, c):
+        input, channel_mean, channel_var, eps = inputs[0:4]
+        weight, bias = inputs[4:6]
+        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
+        inv_var_wt = f("*", invsqrt_channel_var, weight)
+        neg_channel_mean = f("-", channel_mean)
+        outvar =\
+            f("+", f("*", input, inv_var_wt),
+                   f("+", f("*", neg_channel_mean, inv_var_wt),
+                          bias))
+        return (outvar,), (True,)
+
+    @subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3)
+    def syncbn_stage2(inputs, f, c):
+        running_mean, running_var, momentum = inputs[0:3]
+        reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
+        running_mean = f("*", running_mean, momentum)
+        running_mean =\
+            f("+", running_mean,
+                   f("*", f("-", c(1), momentum),
+                          channel_mean))
+        channel_variance_unbiased =\
+            f("+", f("/", f("**", channel_x1s, c(2)),
+                          f("*", f("-", reduce_size),
+                                 f("-", reduce_size, c(1)))),
+                   f("/", channel_x2s,
+                          f("-", reduce_size, c(1))))
+        running_var = f("*", running_var, momentum)
+        running_var =\
+            f("+", running_var,
+                   f("*", f("-", c(1), momentum),
+                          channel_variance_unbiased))
+        return (running_mean, running_var), (True, True)
+
+    @subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3)
+    def syncbn_concat_stats(inputs, f, c):
+        reduce_size, channel_x1s, channel_x2s = inputs[0:3]
+        reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
+        stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
+        return (stats,), (True,)
+
+    @subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3)
+    def syncbn_split_stats(inputs, f, c):
+        stats = inputs[0]
+        c_1 = c(1, dtype="int32")
+        channel_x1s_end = c(channels+1, dtype="int32")
+        def _subtensor(src, axis, begin, end):
+            items = (axis, (begin is not None), (end is not None), False, False),
+            args = ()
+            if begin is not None:
+                args += begin,
+            if end is not None:
+                args += end,
+            return f(builtin.Subtensor(items=items), src, *args)
+        reduce_size = _subtensor(stats, 1, None, c_1)
+        channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
+        channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
+        reduce_size = f(builtin.Reshape(), reduce_size, c_1)
+        return (reduce_size, channel_x1s, channel_x2s), (False, True, True)
+    # fmt: on
+    return (
+        syncbn_stage0,
+        syncbn_stage1,
+        syncbn_stage1_inference,
+        syncbn_stage2,
+        syncbn_concat_stats,
+        syncbn_split_stats,
+    )
+
+
 def sync_batch_norm(
     inp: Tensor,
     running_mean: Tensor,
@@ -1193,52 +1303,55 @@ def sync_batch_norm(
     assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
         eps_mode
     )
-    _channels = inp.shape[1]
+    # TODO: cudnnBn fastpath
+    _channels = make_shape_tuple(inp.shape)[1]
     _ndim = inp.ndim
     _device = inp.device
     _dtype = inp.dtype
-    _param_shape = (1, _channels) + (1,) * (_ndim - 2)
-    _reduce_axis = [0] + [i for i in range(2, _ndim)]
 
-    if training:
+    def _make_full_if_none(x, value):
+        if x is None:
+            (x,) = Const(value, dtype=inp.dtype, device=_device)()
+            (result,) = apply(builtin.Broadcast(), x, reduce_shape)
+            return result
+        elif x.ndim == 1:
+            (result,) = apply(builtin.Reshape(), x, reduce_shape)
+            return result
+        return x
+
+    (
+        syncbn_stage0,
+        syncbn_stage1,
+        syncbn_stage1_inference,
+        syncbn_stage2,
+        syncbn_concat_stats,
+        syncbn_split_stats,
+    ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)
+
+    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)
 
-        def _sum_on_channel(inp):
-            return inp.sum(axis=_reduce_axis, keepdims=True)
+    eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
 
-        reduce_size = inp.shape[0]
-        for i in range(2, _ndim):
-            reduce_size = reduce_size * inp.shape[i]
-        channel_x1s = _sum_on_channel(inp)
-        channel_x2s = _sum_on_channel(inp ** 2)
+    weight = _make_full_if_none(weight, 1)
+    bias = _make_full_if_none(bias, 0)
 
+    if training:
         if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            reduce_size = broadcast_to(
-                Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
-            )
-            stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
+            (stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
             stat = all_reduce_sum(stat, group)
-            reduce_size = stat[:, :1].reshape(1)
-            channel_x1s = stat[:, 1 : 1 + _channels]
-            channel_x2s = stat[:, 1 + _channels :]
+            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)
 
-        channel_mean = channel_x1s / reduce_size
-        channel_variance = (
-            channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
+        outvar, channel_mean, *_ = apply(
+            syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
         )
     else:
         assert running_var is not None and running_mean is not None
-        channel_variance = running_var.reshape(*_param_shape)
-        channel_mean = running_mean.reshape(*_param_shape)
-
-    invsqrt_channel_variance = (
-        maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps
-    ) ** -0.5
-
-    if weight is not None:
-        weight = weight.reshape(*_param_shape)
-    if bias is not None:
-        bias = bias.reshape(*_param_shape)
+        channel_mean = running_mean
+        channel_var = running_var
+        outvar, *_ = apply(
+            syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
+        )
 
     # outvar = output * weight + bias
     # where output = inp * invsqrt_channel_variance + (
@@ -1246,28 +1359,18 @@ def sync_batch_norm(
     # )
 
     # Manually expand output for gopt
-    if weight is not None:
-        inv_var_wt = invsqrt_channel_variance * weight
-        neg_channel_mean = -channel_mean
-        if bias is not None:
-            outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
-        else:
-            outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
-    else:
-        outvar = inp * invsqrt_channel_variance + (
-            -channel_mean * invsqrt_channel_variance
-        )
-        if bias is not None:
-            outvar = outvar + bias
-
     if training and running_var is not None and running_mean is not None:
-        running_mean *= momentum
-        running_mean += (1 - momentum) * channel_mean
-        channel_variance_unbiased = channel_x1s ** 2 / (
-            -reduce_size * (reduce_size - 1)
-        ) + channel_x2s / (reduce_size - 1)
-        running_var *= momentum
-        running_var += (1 - momentum) * channel_variance_unbiased
+        momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
+        running_mean[...], running_var[...] = apply(
+            syncbn_stage2,
+            running_mean,
+            running_var,
+            momentum,
+            reduce_size,
+            channel_x1s,
+            channel_x2s,
+            channel_mean,
+        )
 
     return outvar
 
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 53bd63387..57a596248 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -66,7 +66,7 @@ def is_tracing():
 @contextlib.contextmanager
 def exclude_from_trace():
     global skip_tracing
-    if skip_tracing:
+    if skip_tracing or (active_trace is None):
         yield
         return
     try:
diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp
index 32a1cb68c..b3704b747 100644
--- a/imperative/python/src/common.cpp
+++ b/imperative/python/src/common.cpp
@@ -58,6 +58,9 @@ void init_common(py::module m) {
             .def_property_readonly("logical_name", [](const CompNode& cn) {
                 return cn.to_string_logical();
             })
+            .def_property_readonly("physical_name", [](const CompNode& cn) {
+                return cn.to_string();
+            })
             .def_property_readonly("get_mem_status_bytes", [](const CompNode& cn) {
                 return cn.get_mem_status_bytes();
             })
@@ -70,6 +73,7 @@ void init_common(py::module m) {
                         cn.to_string_physical().c_str(),
                         cn.to_string_logical().c_str());
             })
+            .def("__hash__", [](CompNode cn){ return mgb::hash(cn); })
            .def_static("_sync_all", &CompNode::sync_all)
            .def(py::self == py::self)
            .def_static("_get_device_count", &CompNode::get_device_count,
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index 840a617cd..9ca8c812b 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -15,6 +15,7 @@
 
 #include "megbrain/common.h"
 #include "megbrain/imperative.h"
+#include "megbrain/imperative/graph_builder.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/ops/utility.h"
@@ -477,4 +478,50 @@ void init_ops(py::module m) {
     m.def("set_global_rng_seed", &rng::set_global_rng_seed);
     m.def("get_global_rng_seed", &rng::get_global_rng_seed);
     m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);
+
+    struct PySubgraphBuilder {
+        explicit PySubgraphBuilder(std::string name) : name{name}{}
+        std::string name;
+        Subgraph graph;
+        mgb::SmallVector<bool> output_grad_mask;
+        Subgraph::var_t next_var = 1;
+    };
+
+    py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
+            .def(py::init<std::string>())
+            .def("input", [](PySubgraphBuilder& self){
+                auto var = self.next_var++;
+                self.graph.inputs.push_back(var);
+                return var;
+            })
+            .def("apply", [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op, Subgraph::vars_t inputs, size_t nr_outputs){
+                Subgraph::vars_t outputs;
+                for (size_t i = 0; i < nr_outputs; ++i) {
+                    outputs.push_back(self.next_var++);
+                }
+                self.graph.exprs.push_back({op, inputs, outputs});
+                return outputs;
+            })
+            .def("apply_const", [](PySubgraphBuilder& self, py::object value, mgb::DType dtype, mgb::CompNode cn){
+                auto var = self.next_var++;
+                mgb::HostTensorND hvalue(cn);
+                npy::np2tensor(value.cast<py::array>().ptr(), npy::Meth::copy_into(&hvalue), dtype);
+                self.graph.constants.push_back({var, Tensor::make(hvalue)});
+                return var;
+            })
+            .def("outputs", [](PySubgraphBuilder& self, Subgraph::vars_t outputs){
+                self.graph.outputs = outputs;
+                self.output_grad_mask.resize(outputs.size(), true);
+            })
+            .def("outputs_has_grad", [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad){
+                mgb_assert(self.graph.outputs.size() == self.output_grad_mask.size());
+                self.output_grad_mask = outputs_has_grad;
+            })
+            .def("get", [](PySubgraphBuilder& self){
+                return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
+            })
+            .def("compile", [](PySubgraphBuilder& self, int gopt_level){
+                auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
+                return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
+            });
 }
--
GitLab
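
Usage sketch (illustrative only, not part of the patch): how the new subgraph decorator from utils.py can be used to build and apply a small fused elementwise op. The op name "AxPlusB", the ax_plus_b function and the 2*x+1 computation are made up for illustration; the sketch assumes the decorator and the SubgraphBuilder binding behave exactly as defined in this diff.

import numpy as np

import megengine as mge
from megengine.core._imperative_rt.core2 import apply
from megengine.core.tensor.utils import subgraph

x = mge.tensor(np.random.rand(4).astype("float32"))

# nr_inputs=1: the decorated function receives one symbolic input var,
# plus the apply_expr ("f") and apply_const ("c") helpers.
@subgraph("AxPlusB", x.dtype, x.device, 1, gopt_level=3)
def ax_plus_b(inputs, f, c):
    (inp,) = inputs
    a = c(2.0)                          # constants are folded into the subgraph
    b = c(1.0)
    out = f("+", f("*", inp, a), b)     # 2 * x + 1 via the string op table
    return (out,), (True,)              # one output; gradient flows through it

# The decorator returns an op (compiled when gopt_level is set and the device
# is not CPU), which is applied to ordinary tensors just like syncbn_stage0.
(y,) = apply(ax_plus_b, x)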
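
A NumPy-only illustration (also not part of the patch) of the statistics the SyncBn subgraphs derive from the per-channel sums: SyncBnStage0 produces reduce_size (N), channel_x1s (sum of x) and channel_x2s (sum of x**2); SyncBnStage1 forms the biased variance as x1s**2 / (-N*N) + x2s / N, and SyncBnStage2 uses the unbiased form x1s**2 / (-N*(N-1)) + x2s / (N-1) for the running-variance update.

import numpy as np

x = np.random.rand(8, 3, 4, 4)                 # toy NCHW input
N = x.size // x.shape[1]                       # reduce_size per channel
x1s = x.sum(axis=(0, 2, 3))                    # channel_x1s
x2s = (x ** 2).sum(axis=(0, 2, 3))             # channel_x2s

var_biased = x1s ** 2 / (-N * N) + x2s / N                # SyncBnStage1
var_unbiased = x1s ** 2 / (-N * (N - 1)) + x2s / (N - 1)  # SyncBnStage2

assert np.allclose(var_biased, x.var(axis=(0, 2, 3)))
assert np.allclose(var_unbiased, x.var(axis=(0, 2, 3), ddof=1))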