From a4327c4d25ceacbfa0d73f0c03c97dd1f08ae864 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 17 Mar 2022 14:13:11 +0800
Subject: [PATCH] perf(imperative): add dim_expansion transform for conv/bn1d

GitOrigin-RevId: d14a69424d9e15ac9ae29a9e4bcd9b532dc76200
---
 imperative/python/megengine/functional/nn.py | 64 ++++++-------
 .../python/megengine/module/batchnorm.py     | 16 ----
 imperative/python/src/tensor.cpp             |  6 ++
 imperative/python/src/transformation.h       |  3 +-
 .../python/test/integration/test_dtr.py      |  2 +-
 .../impl/transformations/dim_expansion.cpp   | 95 +++++++++++++++++++
 .../transformations/dim_expansion.h          | 19 ++++
 7 files changed, 153 insertions(+), 52 deletions(-)
 create mode 100644 imperative/src/impl/transformations/dim_expansion.cpp
 create mode 100644 imperative/src/include/megbrain/imperative/transformations/dim_expansion.h

diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index a0fb5708e..346766eea 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -41,7 +41,6 @@ from ..distributed import WORLD, is_distributed
 from ..jit import exclude_from_trace
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
-from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
 from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
@@ -94,14 +93,15 @@ __all__ = [
 
 
 def expand_hw(x):
-    # NOTE: >1d array is accepted, as long as 1 <= size <= 2
-    try:
-        x = int(x)
-        return [x, x]
-    except (TypeError, ValueError):
-        pass
-    h, w = x
-    return int(h), int(w)
+    if isinstance(x, Sequence):
+        return int(x[0]), int(x[1])
+    return int(x), int(x)
+
+
+def expand_dhw(x):
+    if isinstance(x, Sequence):
+        return int(x[0]), int(x[1]), int(x[2])
+    return int(x), int(x), int(x)
 
 
 def linear(
@@ -177,11 +177,8 @@ def conv1d(
     if weight.dtype != dtype:
         weight = weight.astype(dtype)
 
-    inp = expand_dims(inp, 3)
-    weight = expand_dims(weight, 3)
     if bias is not None:
         assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
-        bias = expand_dims(bias, 3)
 
     stride_h = stride
     pad_h = padding
@@ -206,7 +203,6 @@ def conv1d(
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
-    output = squeeze(output, 3)
     return output
 
 
@@ -314,9 +310,9 @@ def conv3d(
 
     D, H, W = 0, 1, 2
 
-    pad = _triple(padding)
-    stride = _triple_nonzero(stride)
-    dilate = _triple_nonzero(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
 
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(
@@ -572,9 +568,9 @@ def conv_transpose3d(
         output tensor.
     """
     D, H, W = 0, 1, 2
-    pad = _triple(padding)
-    stride = _triple_nonzero(stride)
-    dilate = _triple_nonzero(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
 
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
@@ -618,9 +614,9 @@ def max_pool2d(
 
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _pair_nonzero(kernel_size)
-    stride_h, stride_w = _pair_nonzero(stride)
-    padding_h, padding_w = _pair(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(
@@ -662,9 +658,9 @@ def avg_pool2d(
 
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _pair_nonzero(kernel_size)
-    stride_h, stride_w = _pair_nonzero(stride)
-    padding_h, padding_w = _pair(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(
@@ -1779,10 +1775,10 @@ def sliding_window(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    padding_h, padding_w = _pair(padding)
-    stride_h, stride_w = _pair_nonzero(stride)
-    dilation_h, dilation_w = _pair_nonzero(dilation)
-    window_h, window_w = _pair_nonzero(kernel_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
 
     op = builtin.Images2Neibs(
         pad_h=padding_h,
@@ -1818,11 +1814,11 @@ def sliding_window_transpose(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    output_h, output_w = _pair_nonzero(output_size)
-    padding_h, padding_w = _pair(padding)
-    stride_h, stride_w = _pair_nonzero(stride)
-    dilation_h, dilation_w = _pair_nonzero(dilation)
-    window_h, window_w = _pair_nonzero(kernel_size)
+    output_h, output_w = expand_hw(output_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
 
     expected_h = (
         output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py
index 993ccfa77..0ca7e7342 100644
--- a/imperative/python/megengine/module/batchnorm.py
+++ b/imperative/python/megengine/module/batchnorm.py
@@ -80,19 +80,6 @@ class _BatchNorm(Module):
             self.track_running_stats == False
         ), "track_running_stats can not be initilized to False and changed to True later"
 
-        inp_shape = inp.shape
-        _ndims = len(inp_shape)
-        if _ndims != 4:
-            origin_shape = inp_shape
-            if _ndims == 2:
-                n, c = inp_shape[0], inp_shape[1]
-                new_shape = (n, c, 1, 1)
-            elif _ndims == 3:
-                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
-                new_shape = (n, c, h, 1)
-
-            inp = inp.reshape(new_shape)
-
         _weight = self.weight
         _bias = self.bias
 
@@ -130,9 +117,6 @@ class _BatchNorm(Module):
             param_dim=self.param_dim,
         )
 
-        if _ndims != 4:
-            output = output.reshape(origin_shape)
-
         return output
 
     def _module_info_string(self) -> str:
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 556308e5d..75fea2ba0 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -15,6 +15,7 @@
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/profiler.h"
+#include "megbrain/imperative/transformations/dim_expansion.h"
 #include "megbrain/imperative/transformations/dtype_promote.h"
 #include "megbrain/imperative/transformations/eval.h"
 #include "megbrain/imperative/transformations/lazy.h"
@@ -61,11 +62,13 @@ struct SymbolVarContext {
     std::shared_ptr<SymbolTransformation> symbol_tsf;
     std::shared_ptr<ScalarTransformation> scalar_tsf;
     std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
+    std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;
 
     SymbolVarContext(cg::ComputingGraph* graph) {
         symbol_tsf = std::make_shared<SymbolTransformation>(graph);
         scalar_tsf = std::make_shared<ScalarTransformation>();
         dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
+        dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
         Transformation::swap_context(context);
     }
 
@@ -73,6 +76,7 @@ struct SymbolVarContext {
         symbol_tsf->register_at(Transformation::top());
         scalar_tsf->register_at(Transformation::top());
         dtype_promote_tsf->register_at(Transformation::top());
+        dim_expansion_tsf->register_at(Transformation::top());
     }
 
     ValueRef symvar2val(py::handle py_symbol_var) {
@@ -452,6 +456,8 @@ void init_tensor(py::module m) {
             std::make_shared<ScalarTransformation>());
     transformations.register_at<Segment::DTypePromote>(
             std::make_shared<DTypePromoteTransformation>());
+    transformations.register_at<Segment::DimExpansion>(
+            std::make_shared<DimExpansionTransformation>());
 
     static py::exception<interpreter::AsyncError> py_async_error(
             m, "AsyncError", PyExc_RuntimeError);
diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h
index e02135adb..bafc43693 100644
--- a/imperative/python/src/transformation.h
+++ b/imperative/python/src/transformation.h
@@ -26,13 +26,14 @@ struct TransformationManager {
     enum Segment {
         ModuleTrace,
         DTypePromote,
+        DimExpansion,
         Grad,
         Scalar,
         Trace,
         Eval,
     };
 
-    std::array<std::vector<std::shared_ptr<Transformation>>, 6> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments;
 
     template <Segment segment>
     void register_at(std::shared_ptr<Transformation> transformation) {
diff --git a/imperative/python/test/integration/test_dtr.py b/imperative/python/test/integration/test_dtr.py
index 68b915948..56b47725d 100644
--- a/imperative/python/test/integration/test_dtr.py
+++ b/imperative/python/test/integration/test_dtr.py
@@ -91,7 +91,7 @@ class ResNet(M.Module):
 
 
 def run_dtr_resnet1202():
-    batch_size = 8
+    batch_size = 7
     resnet1202 = ResNet(BasicBlock, [200, 200, 200])
     opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
     gm = GradManager().attach(resnet1202.parameters())
diff --git a/imperative/src/impl/transformations/dim_expansion.cpp b/imperative/src/impl/transformations/dim_expansion.cpp
new file mode 100644
index 000000000..2f1c7c309
--- /dev/null
+++ b/imperative/src/impl/transformations/dim_expansion.cpp
@@ -0,0 +1,95 @@
+#include "megbrain/imperative/transformations/dim_expansion.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+namespace mgb::imperative {
+
+namespace {
+using DimExpansionRule = std::function<ValueRefList(const OpDef&, Span<ValueRef>)>;
+static std::unordered_map<Typeinfo*, DimExpansionRule> dim_expansion_rules;
+
+template <typename T>
+void register_dim_expansion_rules(const DimExpansionRule& rule) {
+    dim_expansion_rules[T::typeinfo()] = [rule](const OpDef& def,
+                                                Span<ValueRef> inputs) {
+        return rule(def.cast_final_safe<T>(), inputs);
+    };
+}
+
+ValueRefList conv1d_rule(const OpDef& op, Span<ValueRef> inputs) {
+    bool need_expand = inputs.at(0).shape()->ndim == 3;
+    if (!need_expand)
+        return imperative::apply(op, inputs);
+
+    ValueRefList converted(inputs.size());
+    std::vector<int32_t> axis = {(int32_t)3};
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        converted[i] = imperative::apply(ApplyOp(*AddAxis::make(axis)), inputs[i])[0];
+    }
+
+    auto outputs = imperative::apply(op, converted);
+    outputs[0] = imperative::apply(ApplyOp(*RemoveAxis::make(axis)), outputs[0])[0];
+    return outputs;
+}
+
+ValueRefList bn1d_rule(const OpDef& op, Span<ValueRef> inputs) {
+    size_t ndim = inputs.at(0).shape()->ndim;
+    bool need_expand = (ndim == 2 || ndim == 3);
+    if (!need_expand)
+        return imperative::apply(op, inputs);
+
+    ValueRefList converted(inputs.size());
+    std::vector<int32_t> axis = {(int32_t)3};
+    if (ndim == 2) {
+        axis.insert(axis.begin(), (int32_t)2);
+    }
+    converted[0] = imperative::apply(ApplyOp(*AddAxis::make(axis)), inputs[0])[0];
+    for (size_t i = 1; i < inputs.size(); ++i) {
+        converted[i] = inputs[i];
+    }
+
+    std::reverse(std::begin(axis), std::end(axis));
+    auto outputs = imperative::apply(op, converted);
+    size_t idx = outputs.size() - 1;
+    outputs[idx] = imperative::apply(ApplyOp(*RemoveAxis::make(axis)), outputs[idx])[0];
+    return outputs;
+}
+
+struct DimExpansionRuleRegistry {
+    DimExpansionRuleRegistry() {
+        register_dim_expansion_rules<Convolution>(conv1d_rule);
+        register_dim_expansion_rules<BatchNorm>(bn1d_rule);
+    }
+} register_helper;
+
+} // namespace
+
+ValueRefList DimExpansionTransformation::apply_transformation(
+        const Operator& op, Span<ValueRef> inputs) {
+    if (auto apply_op = op.as<ApplyOp>()) {
+        auto iter = dim_expansion_rules.find(apply_op->op().dyn_typeinfo());
+        if (iter != dim_expansion_rules.end()) {
+            return iter->second(apply_op->op(), inputs);
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    }
+    return imperative::apply(op, inputs);
+}
+
+ValueRef DimExpansionTransformation::unwrap(ValueRef value) {
+    return value;
+}
+
+std::string DimExpansionTransformation::name() const {
+    return "DimExpansionTransformation";
+}
+
+void DimExpansionTransformation::on_register() {
+    // printf("DimExpansionTransformation has been registered\n");
+}
+
+void DimExpansionTransformation::on_unregister() noexcept {
+    // printf("DimExpansionTransformation has been unregistered\n");
+}
+
+} // namespace mgb::imperative
\ No newline at end of file
diff --git a/imperative/src/include/megbrain/imperative/transformations/dim_expansion.h b/imperative/src/include/megbrain/imperative/transformations/dim_expansion.h
new file mode 100644
index 000000000..f4d0b0282
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/transformations/dim_expansion.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb::imperative {
+
+class DimExpansionTransformation final : public Transformation {
+private:
+public:
+    ValueRefList apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override;
+    ValueRef unwrap(ValueRef value) override;
+    std::string name() const override;
+    void on_register() override;
+    void on_unregister() noexcept override;
+};
+
+} // namespace mgb::imperative
--
GitLab
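
[Editor's note, not part of the patch] A minimal usage sketch of the behaviour this commit moves out of Python: conv1d and BatchNorm1d keep their natural 1d layouts, and DimExpansionTransformation inserts AddAxis/RemoveAxis in the dispatch layer instead of expand_dims/squeeze (nn.py) and reshape (batchnorm.py). The snippet assumes only MegEngine's public megengine.functional.conv1d and megengine.module.BatchNorm1d; the shapes, values, and variable names are illustrative and not taken from the patch or its tests.

    # Illustrative sketch: exercises the conv1d / bn1d paths handled by the new transform.
    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M

    inp = mge.tensor(np.random.randn(2, 4, 16).astype("float32"))    # (N, C, L), ndim == 3
    weight = mge.tensor(np.random.randn(8, 4, 3).astype("float32"))  # (out_channels, in_channels, kernel_size)

    # conv1d no longer expands/squeezes dims in Python; the transform wraps the builtin op.
    out = F.conv1d(inp, weight, stride=1, padding=1)
    print(out.shape)  # expected: (2, 8, 16)

    # BatchNorm1d passes (N, C) or (N, C, L) inputs through unchanged;
    # bn1d_rule expands them to 4d before batch_norm and squeezes the result back.
    bn = M.BatchNorm1d(4)
    print(bn(inp).shape)  # expected: (2, 4, 16)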