Commit a4327c4d authored by Megvii Engine Team

perf(imperative): add dim_expansion transform for conv/bn1d

GitOrigin-RevId: d14a69424d9e15ac9ae29a9e4bcd9b532dc76200
Parent 72a70dd6
@@ -41,7 +41,6 @@ from ..distributed import WORLD, is_distributed
 from ..jit import exclude_from_trace
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
-from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
 from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
@@ -94,14 +93,15 @@ __all__ = [

 def expand_hw(x):
-    # NOTE: >1d array is accepted, as long as 1 <= size <= 2
-    try:
-        x = int(x)
-        return [x, x]
-    except (TypeError, ValueError):
-        pass
-    h, w = x
-    return int(h), int(w)
+    if isinstance(x, Sequence):
+        return int(x[0]), int(x[1])
+    return int(x), int(x)
+
+
+def expand_dhw(x):
+    if isinstance(x, Sequence):
+        return int(x[0]), int(x[1]), int(x[2])
+    return int(x), int(x), int(x)


 def linear(
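
For reference, the new helpers accept either a scalar or a per-axis sequence: a sequence is read positionally, a scalar is broadcast to every spatial axis. A standalone sketch of that behaviour (plain Python, mirroring the added code):

    from collections.abc import Sequence

    def expand_hw(x):
        # a sequence is taken as (h, w); a scalar is broadcast to both axes
        if isinstance(x, Sequence):
            return int(x[0]), int(x[1])
        return int(x), int(x)

    def expand_dhw(x):
        # same idea for the 3-d case: (d, h, w)
        if isinstance(x, Sequence):
            return int(x[0]), int(x[1]), int(x[2])
        return int(x), int(x), int(x)

    assert expand_hw(3) == (3, 3)
    assert expand_hw((2, 5)) == (2, 5)
    assert expand_dhw(1) == (1, 1, 1)

These two helpers also replace the removed _pair/_pair_nonzero/_triple/_triple_nonzero imports at every call site below.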
@@ -177,11 +177,8 @@ def conv1d(
     if weight.dtype != dtype:
         weight = weight.astype(dtype)
-    inp = expand_dims(inp, 3)
-    weight = expand_dims(weight, 3)
     if bias is not None:
         assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
-        bias = expand_dims(bias, 3)
     stride_h = stride
     pad_h = padding
@@ -206,7 +203,6 @@ def conv1d(
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
-    output = squeeze(output, 3)
     return output
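
The expand_dims/squeeze pair is removed from the Python wrapper because the trailing-axis expansion for 1-d convolution now happens in the new DimExpansionTransformation (see the C++ rule added at the end of this commit). The user-facing contract of conv1d is unchanged; a quick sanity-check sketch, assuming a regular MegEngine install:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    # conv1d still takes an (N, C, L) input and an (OC, IC, K) weight and
    # returns an (N, OC, L_out) output; the temporary 4-d view needed by the
    # underlying 2-d convolution is now created inside the dispatch layer.
    inp = mge.tensor(np.random.randn(2, 3, 16).astype("float32"))
    weight = mge.tensor(np.random.randn(8, 3, 5).astype("float32"))
    out = F.conv1d(inp, weight, stride=1, padding=2)
    print(out.shape)  # expected: (2, 8, 16)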
@@ -314,9 +310,9 @@ def conv3d(
     D, H, W = 0, 1, 2

-    pad = _triple(padding)
-    stride = _triple_nonzero(stride)
-    dilate = _triple_nonzero(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)

     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(
@@ -572,9 +568,9 @@ def conv_transpose3d(
         output tensor.
     """
     D, H, W = 0, 1, 2
-    pad = _triple(padding)
-    stride = _triple_nonzero(stride)
-    dilate = _triple_nonzero(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
@@ -618,9 +614,9 @@ def max_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _pair_nonzero(kernel_size)
-    stride_h, stride_w = _pair_nonzero(stride)
-    padding_h, padding_w = _pair(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(
@@ -662,9 +658,9 @@ def avg_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _pair_nonzero(kernel_size)
-    stride_h, stride_w = _pair_nonzero(stride)
-    padding_h, padding_w = _pair(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(
@@ -1779,10 +1775,10 @@ def sliding_window(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    padding_h, padding_w = _pair(padding)
-    stride_h, stride_w = _pair_nonzero(stride)
-    dilation_h, dilation_w = _pair_nonzero(dilation)
-    window_h, window_w = _pair_nonzero(kernel_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
     op = builtin.Images2Neibs(
         pad_h=padding_h,
@@ -1818,11 +1814,11 @@ def sliding_window_transpose(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    output_h, output_w = _pair_nonzero(output_size)
-    padding_h, padding_w = _pair(padding)
-    stride_h, stride_w = _pair_nonzero(stride)
-    dilation_h, dilation_w = _pair_nonzero(dilation)
-    window_h, window_w = _pair_nonzero(kernel_size)
+    output_h, output_w = expand_hw(output_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
     expected_h = (
         output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
......
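
All the 2-d pooling and sliding-window entry points above now normalise kernel_size/stride/padding/dilation through expand_hw instead of the removed _pair helpers, so scalar and 2-tuple arguments remain interchangeable. A small check, assuming a regular MegEngine install:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.randn(1, 1, 8, 8).astype("float32"))

    # scalar and tuple arguments go through the same expand_hw path
    a = F.max_pool2d(x, kernel_size=2, stride=2)
    b = F.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2))
    np.testing.assert_allclose(a.numpy(), b.numpy())
    print(a.shape)  # expected: (1, 1, 4, 4)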
@@ -80,19 +80,6 @@ class _BatchNorm(Module):
             self.track_running_stats == False
         ), "track_running_stats can not be initilized to False and changed to True later"

-        inp_shape = inp.shape
-        _ndims = len(inp_shape)
-        if _ndims != 4:
-            origin_shape = inp_shape
-            if _ndims == 2:
-                n, c = inp_shape[0], inp_shape[1]
-                new_shape = (n, c, 1, 1)
-            elif _ndims == 3:
-                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
-                new_shape = (n, c, h, 1)
-            inp = inp.reshape(new_shape)
-
         _weight = self.weight
         _bias = self.bias
@@ -130,9 +117,6 @@ class _BatchNorm(Module):
                 param_dim=self.param_dim,
             )

-        if _ndims != 4:
-            output = output.reshape(origin_shape)
-
         return output

     def _module_info_string(self) -> str:
......
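
_BatchNorm.forward no longer reshapes 2-d/3-d inputs to 4-d in Python; that now happens in the bn1d_rule of the dim-expansion transformation below. For reference, the deleted branch was equivalent to this shape mapping:

    def to_4d(inp_shape):
        # what the removed Python code did before calling the 4-d batch-norm
        # kernel: (N, C) -> (N, C, 1, 1) and (N, C, H) -> (N, C, H, 1),
        # with the output reshaped back to the original shape afterwards
        if len(inp_shape) == 2:
            n, c = inp_shape
            return (n, c, 1, 1)
        if len(inp_shape) == 3:
            n, c, h = inp_shape
            return (n, c, h, 1)
        return inp_shape

    assert to_4d((8, 16)) == (8, 16, 1, 1)
    assert to_4d((8, 16, 32)) == (8, 16, 32, 1)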
@@ -15,6 +15,7 @@
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/profiler.h"
+#include "megbrain/imperative/transformations/dim_expansion.h"
 #include "megbrain/imperative/transformations/dtype_promote.h"
 #include "megbrain/imperative/transformations/eval.h"
 #include "megbrain/imperative/transformations/lazy.h"
@@ -61,11 +62,13 @@ struct SymbolVarContext {
     std::shared_ptr<SymbolTransformation> symbol_tsf;
     std::shared_ptr<ScalarTransformation> scalar_tsf;
     std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
+    std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;

     SymbolVarContext(cg::ComputingGraph* graph) {
         symbol_tsf = std::make_shared<SymbolTransformation>(graph);
         scalar_tsf = std::make_shared<ScalarTransformation>();
         dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
+        dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
         Transformation::swap_context(context);
     }
@@ -73,6 +76,7 @@ struct SymbolVarContext {
         symbol_tsf->register_at(Transformation::top());
         scalar_tsf->register_at(Transformation::top());
         dtype_promote_tsf->register_at(Transformation::top());
+        dim_expansion_tsf->register_at(Transformation::top());
     }

     ValueRef symvar2val(py::handle py_symbol_var) {
@@ -452,6 +456,8 @@ void init_tensor(py::module m) {
             std::make_shared<ScalarTransformation>());
     transformations.register_at<Segment::DTypePromote>(
             std::make_shared<DTypePromoteTransformation>());
+    transformations.register_at<Segment::DimExpansion>(
+            std::make_shared<DimExpansionTransformation>());

     static py::exception<interpreter::AsyncError> py_async_error(
             m, "AsyncError", PyExc_RuntimeError);
......
@@ -26,13 +26,14 @@ struct TransformationManager {
     enum Segment {
         ModuleTrace,
         DTypePromote,
+        DimExpansion,
         Grad,
         Scalar,
         Trace,
         Eval,
     };

-    std::array<std::vector<std::shared_ptr<Transformation>>, 6> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments;

     template <Segment segment>
     void register_at(std::shared_ptr<Transformation> transformation) {
......
@@ -91,7 +91,7 @@ class ResNet(M.Module):

 def run_dtr_resnet1202():
-    batch_size = 8
+    batch_size = 7
     resnet1202 = ResNet(BasicBlock, [200, 200, 200])
     opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
     gm = GradManager().attach(resnet1202.parameters())
......
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb::imperative {
namespace {
using DimExpansionRule = std::function<ValueRefList(const OpDef&, Span<ValueRef>)>;
static std::unordered_map<Typeinfo*, DimExpansionRule> dim_expansion_rules;
template <typename T>
void register_dim_expansion_rules(const DimExpansionRule& rule) {
dim_expansion_rules[T::typeinfo()] = [rule](const OpDef& def,
Span<ValueRef> inputs) {
return rule(def.cast_final_safe<T>(), inputs);
};
}
ValueRefList conv1d_rule(const OpDef& op, Span<ValueRef> inputs) {
bool need_expand = inputs.at(0).shape()->ndim == 3;
if (!need_expand)
return imperative::apply(op, inputs);
ValueRefList converted(inputs.size());
std::vector<int32_t> axis = {(int32_t)3};
for (size_t i = 0; i < inputs.size(); ++i) {
converted[i] = imperative::apply(ApplyOp(*AddAxis::make(axis)), inputs[i])[0];
}
auto outputs = imperative::apply(op, converted);
outputs[0] = imperative::apply(ApplyOp(*RemoveAxis::make(axis)), outputs[0])[0];
return outputs;
}
ValueRefList bn1d_rule(const OpDef& op, Span<ValueRef> inputs) {
size_t ndim = inputs.at(0).shape()->ndim;
bool need_expand = (ndim == 2 || ndim == 3);
if (!need_expand)
return imperative::apply(op, inputs);
ValueRefList converted(inputs.size());
std::vector<int32_t> axis = {(int32_t)3};
if (ndim == 2) {
axis.insert(axis.begin(), (int32_t)2);
}
converted[0] = imperative::apply(ApplyOp(*AddAxis::make(axis)), inputs[0])[0];
for (size_t i = 1; i < inputs.size(); ++i) {
converted[i] = inputs[i];
}
std::reverse(std::begin(axis), std::end(axis));
auto outputs = imperative::apply(op, converted);
size_t idx = outputs.size() - 1;
outputs[idx] = imperative::apply(ApplyOp(*RemoveAxis::make(axis)), outputs[idx])[0];
return outputs;
}
struct DimExpansionRuleRegistry {
DimExpansionRuleRegistry() {
register_dim_expansion_rules<Convolution>(conv1d_rule);
register_dim_expansion_rules<BatchNorm>(bn1d_rule);
}
} register_helper;
} // namespace
ValueRefList DimExpansionTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto apply_op = op.as<ApplyOp>()) {
auto iter = dim_expansion_rules.find(apply_op->op().dyn_typeinfo());
if (iter != dim_expansion_rules.end()) {
return iter->second(apply_op->op(), inputs);
} else {
return imperative::apply(op, inputs);
}
}
return imperative::apply(op, inputs);
}
ValueRef DimExpansionTransformation::unwrap(ValueRef value) {
return value;
}
std::string DimExpansionTransformation::name() const {
return "DimExpansionTransformation";
}
void DimExpansionTransformation::on_register() {
// printf("DimExpansionTransformation has been registered\n");
}
void DimExpansionTransformation::on_unregister() noexcept {
// printf("DimExpansionTransformation has been unregistered\n");
}
} // namespace mgb::imperative
\ No newline at end of file
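
For readers who prefer Python, the two rules above amount to: append a trailing axis (plus one more axis for 2-d batch-norm inputs), run the 4-d kernel, then strip those axes from the normalised output. A rough, hypothetical NumPy analog (not part of the commit):

    import numpy as np

    def with_extra_axes(apply_4d_op, inp, extra_axes):
        # hypothetical analog of conv1d_rule / bn1d_rule: add axes, run the
        # 4-d op, then remove the same axes from the result
        expanded = np.expand_dims(inp, axis=extra_axes)
        out = apply_4d_op(expanded)
        return np.squeeze(out, axis=extra_axes)

    # conv1d-style: a 3-d (N, C, L) tensor gets one trailing axis -> (N, C, L, 1)
    y = with_extra_axes(lambda t: t * 2.0, np.ones((2, 3, 16)), (3,))
    assert y.shape == (2, 3, 16)

    # bn1d-style on a 2-d (N, C) tensor: two trailing axes -> (N, C, 1, 1)
    z = with_extra_axes(lambda t: t + 1.0, np.ones((4, 8)), (2, 3))
    assert z.shape == (4, 8)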
#pragma once

#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/value.h"

namespace mgb::imperative {

class DimExpansionTransformation final : public Transformation {
private:
public:
    ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override;

    ValueRef unwrap(ValueRef value) override;

    std::string name() const override;

    void on_register() override;

    void on_unregister() noexcept override;
};

}  // namespace mgb::imperative