Commit fc633ce4 authored by: Megvii Engine Team

fix(imperative/amp): fix custom grad in Subgraph

GitOrigin-RevId: 1c728d6ab97e8a49f84bf7e309a288938111d7be
Parent 673b295d
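For context, the path this commit exercises is a custom-gradient `Function` running under the NHWC format transformation (the C++ changes below add `AttachGrad`/`SetGrad` handling for exactly this case). The following is a minimal, hypothetical repro sketch written for this note, not code from the commit; the `Scale2` class, shapes, and values are illustrative and assume the public `megengine.autodiff.Function` and `GradManager` APIs.

```python
import numpy as np
import megengine as mge
from megengine.autodiff import Function, GradManager


class Scale2(Function):
    def forward(self, x):
        return x * 2

    def backward(self, dy):
        # custom gradient: pass the doubled upstream gradient through
        return dy * 2


x = mge.tensor(np.ones((1, 3, 4, 2), dtype="float32"), format="nhwc")
gm = GradManager().attach([x])
with gm:
    y = Scale2()(x)
    gm.backward(y.sum())
print(x.grad)  # should be produced without format/Subgraph-related errors
```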
......@@ -50,36 +50,36 @@ class autocast:
self._origin_enabled = None
self._origin_high = None
self._origin_low = None
self._origin_compute_mode = None
self._origin_configs = None
def __enter__(self):
-        self._origin_enabled = amp._enabled
-        amp._enabled = self.enabled
-        amp._set_amp_dtype_autocast(self.enabled)
-        if not self.enabled:
-            return
+        if self.enabled:
+            self._origin_enabled = amp._enabled
+            self._origin_high = amp._get_amp_high_prec_dtype()
+            self._origin_low = amp._get_amp_low_prec_dtype()
+            amp._enabled = self.enabled
+            amp._set_amp_dtype_autocast(self.enabled)
+            amp._set_amp_high_prec_dtype(self.high_prec_dtype)
+            amp._set_amp_low_prec_dtype(self.low_prec_dtype)
-        self._origin_high = amp._get_amp_high_prec_dtype()
-        self._origin_low = amp._get_amp_low_prec_dtype()
-        amp._set_amp_high_prec_dtype(self.high_prec_dtype)
-        amp._set_amp_low_prec_dtype(self.low_prec_dtype)
-        self._origin_configs = _config._reset_execution_config(compute_mode="float32")
+            self._origin_configs = _config._reset_execution_config(
+                compute_mode="float32"
+            )
def __exit__(self, *args):
-        amp._enabled = self._origin_enabled
-        amp._set_amp_dtype_autocast(self._origin_enabled)
-        if not self.enabled:
-            return
-        amp._set_amp_high_prec_dtype(self._origin_high)
-        amp._set_amp_low_prec_dtype(self._origin_low)
+        if self.enabled:
+            amp._enabled = self._origin_enabled
+            amp._set_amp_dtype_autocast(self._origin_enabled)
+            amp._set_amp_high_prec_dtype(self._origin_high)
+            amp._set_amp_low_prec_dtype(self._origin_low)
_config._reset_execution_config(*self._origin_compute_mode)
def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not self.enabled:
return func(*args, **kwargs)
with self:
return func(*args, **kwargs)
......
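The `autocast` hunk above changes `__enter__`/`__exit__` so that global AMP state is saved and restored only when `enabled` is True. For reference, a hedged usage sketch of `autocast` as a context manager and as a decorator (standard AMP usage, not part of this diff; `model` and `x` are assumed to exist):

```python
import megengine.amp as amp

# context-manager form: AMP dtypes are swapped only inside the block
with amp.autocast(enabled=True):
    y = model(x)

# decorator form: the wrapper enters the same context around each call
@amp.autocast(enabled=True)
def train_step(x):
    return model(x)
```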
......@@ -10,6 +10,7 @@ from copy import deepcopy
from .. import functional as F
from ..module import Module
from ..tensor import Tensor
from ..core import _config
def _is_nchw_format(param: Tensor):
......@@ -26,10 +27,12 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
else:
raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
# TODO: use initialization from tensor after fixing format setting
-    if inplace:
-        x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
-    else:
-        x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    if x.format != "nhwc":
+        if inplace:
+            data = x.numpy().transpose(*pattern)
+            x[...] = Tensor(data, format="nhwc")
+        else:
+            x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
return x
......
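The `convert_tensor_format` hunk adds an `x.format != "nhwc"` guard so tensors that are already NHWC are not transposed a second time. The sketch below illustrates the 4D permutation involved using plain NumPy; the array contents are illustrative only (the 5D group-conv weight case gets its own pattern in the C++ change further down):

```python
import numpy as np

nchw = np.arange(24).reshape(1, 2, 3, 4)      # N, C, H, W
nhwc = nchw.transpose(0, 2, 3, 1)             # N, H, W, C
assert nhwc.shape == (1, 3, 4, 2)

# reading the shape back, NHWC -> NCHW uses the inverse permutation
assert nhwc.transpose(0, 3, 1, 2).shape == nchw.shape
```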
......@@ -144,7 +144,9 @@ class GradScaler:
def _check_gradients(self, grads, scale):
if len(grads) == 0:
return False
-        return _check_non_finite(grads, scale)
+        rst = _check_non_finite(grads, scale)
+        rst = rst.numpy()
+        return rst
def update(self, new_scale: float = None):
r"""Update the scale factor according to whether encountered overflow grad.
......
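`_check_gradients` now returns a host-side NumPy value instead of a device tensor, so the caller can branch on it directly. The snippet below is a conceptual sketch of how such an overflow flag is typically consumed when updating the loss scale; it is illustrative only and is not the actual `GradScaler.update` implementation (the function name and factors are made up):

```python
def update_scale(scale, found_non_finite, growth_factor=2.0, backoff_factor=0.5):
    # shrink the scale on overflow, grow it otherwise; branching works because
    # the flag is now a plain host-side value rather than a device tensor
    if found_non_finite:
        return scale * backoff_factor
    return scale * growth_factor
```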
......@@ -182,7 +182,6 @@ def _reset_execution_config(
deterministic_kernel=None,
async_level=None,
compute_mode=None,
-    bn_format=None,
auto_format_convert=None,
):
global _benchmark_kernel, _deterministic_kernel, __compute_mode
......@@ -234,11 +233,11 @@ def _override(
def train():
"""
orig_flags = _reset_execution_config(
-        benchmark_kernel,
-        deterministic_kernel,
-        async_level,
-        compute_mode,
-        auto_format_convert,
+        benchmark_kernel=benchmark_kernel,
+        deterministic_kernel=deterministic_kernel,
+        async_level=async_level,
+        compute_mode=compute_mode,
+        auto_format_convert=auto_format_convert,
)
try:
yield
......
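`_override` applies execution-config flags for the duration of a block and restores the originals afterwards; the hunk above switches the call to keyword arguments now that `bn_format` has been removed from `_reset_execution_config`. A hedged usage sketch (the `net`/`inp` names are assumptions; the flags shown are ones this commit itself uses):

```python
from megengine import config

# flags changed inside the block are restored on exit
with config._override(compute_mode="float32", auto_format_convert=False):
    out = net(inp)
```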
......@@ -64,7 +64,9 @@ class Grad:
continue
grad.suppress()
print("before backward")
self._impl.backward(ys, dys)
print("after backward")
for grad in group:
if grad is self:
......
......@@ -24,6 +24,7 @@ from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._imperative_rt.ops import jit_supported
from .._wrap import as_device
from ..autodiff.grad import Function
from .. import _config
from ..ops import builtin
from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
......
......@@ -1226,12 +1226,16 @@ def batch_norm(
bias = make_full_if_none(bias, 0)
if not training:
-        op = builtin.BatchNorm(fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps)
+        op = builtin.BatchNorm(
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps
+        )
        ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
        return ret
    else:
-        op = builtin.BatchNorm(avg_factor=1 - momentum, epsilon=eps)
+        op = builtin.BatchNorm(
+            avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps
+        )
if has_mean or has_var:
running_mean = make_full_if_none(running_mean, 0)
running_var = make_full_if_none(running_var, 1)
......
......@@ -19,7 +19,6 @@ class _BatchNorm(Module):
affine=True,
track_running_stats=True,
freeze=False,
-        param_dim="dim_1c11",
**kwargs
):
super(_BatchNorm, self).__init__(**kwargs)
......@@ -30,7 +29,6 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
-        self.param_dim = param_dim
if self.freeze:
assert (
self._track_running_stats_saved
......@@ -104,7 +102,6 @@ class _BatchNorm(Module):
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
-            param_dim=self.param_dim,
)
return output
......
......@@ -8,6 +8,7 @@ from typing import Union
import numpy as np
from ..core import _config
from ..core._imperative_rt.core2 import (
get_auto_format_convert,
pop_scope,
......@@ -96,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is "
+ str(type(param))
)
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+            param._reset(Tensor(param, no_cache=True))
for name, default in self._defaults.items():
if default is required and name not in param_group:
......@@ -119,10 +120,11 @@ class Optimizer(metaclass=ABCMeta):
def _add_state(self, param, state_name, initializer=None):
if initializer is None:
-            initializer = np.zeros(param.shape, dtype=np.float32)
+            with _config._override(auto_format_convert=False):
+                initializer = np.zeros(param.shape, dtype=np.float32)
state_dict = self._state.setdefault(param, {})
assert state_name not in state_dict
-        state = Tensor(initializer, no_cache=True)
+        state = Tensor(initializer, no_cache=True, format=param.format)
state_dict[state_name] = state
@abstractmethod
......
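The optimizer hunk makes per-parameter state follow the parameter's format: the initializer shape is read with auto format conversion disabled, and the state tensor is created with `format=param.format`. A small sketch of that rule in isolation (illustrative; `param` is assumed to be an NHWC-format `Parameter` already attached to an optimizer):

```python
import numpy as np
from megengine import Tensor, config

with config._override(auto_format_convert=False):
    # raw (physical) NHWC shape, e.g. (1, 1, 1, C) for a batch-norm weight
    init = np.zeros(param.shape, dtype=np.float32)
state = Tensor(init, no_cache=True, format=param.format)  # state inherits the format
```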
......@@ -5,6 +5,7 @@ from typing import Iterable, Union
from ..functional.inplace import _inplace_add_
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
from ..core import _config
class SGD(Optimizer):
......
......@@ -10,7 +10,7 @@ import pytest
import megengine.functional as F
import megengine.module as M
-from megengine import Parameter, Tensor, amp, tensor
+from megengine import Parameter, Tensor, amp, config
class MyModule(M.Module):
......@@ -39,6 +39,22 @@ class MyModule(M.Module):
@pytest.mark.parametrize("is_inplace", [False, True])
def test_convert_module(is_inplace):
m = MyModule()
expected_shape = {
"i.bn.weight": (1, 1, 1, 4),
"i.bn.bias": (1, 1, 1, 4),
"i.bn.running_mean": (1, 1, 1, 4),
"i.bn.running_var": (1, 1, 1, 4),
"conv.weight": (2, 2, 4, 4, 2),
"conv.bias": (1, 1, 1, 4),
"bn.weight": (1, 1, 1, 4),
"bn.bias": (1, 1, 1, 4),
"bn.running_mean": (1, 1, 1, 4),
"bn.running_var": (1, 1, 1, 4),
"param": (1, 1, 1, 3),
"buff": (1, 1, 1, 3),
}
m = amp.convert_module_format(m, is_inplace)
for name, param in m.named_tensors():
assert param.format == "nhwc"
with config._override(auto_format_convert=False):
assert param.shape == expected_shape[name], name
......@@ -3,6 +3,7 @@ import pytest
import megengine as mge
import megengine.functional as F
import megengine.module as M
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
......@@ -36,9 +37,9 @@ def _compare_nchw_nhwc(data, func, is_symbolic=None):
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic)
-    # out1 = func(x1)
+    out1 = func(x1)
    out2 = func(x2)
-    # np.testing.assert_almost_equal(out1, out2, decimal=5)
+    np.testing.assert_almost_equal(out1, out2, decimal=5)
@pytest.mark.parametrize("is_symbolic", [None])
......@@ -322,30 +323,91 @@ def test_pooling2d(pooling, is_symbolic):
_compare_nchw_nhwc(data, func, is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None])
-def test_backward(is_symbolic):
-    data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
-    b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
-    gm = GradManager().attach([w, b])
+def _compare_backward(inps, model, is_symbolic=None):
+    def func(*inps):
+        return model(*inps)
-    def func(x, w, b):
-        return F.conv2d(x, w, b)
+    if is_symbolic is not None:
+        func = trace(func, symbolic=is_symbolic)
+    gm = GradManager().attach(model.parameters())
    with gm:
-        if is_symbolic is not None:
-            func = trace(func, symbolic=is_symbolic)
-        x = func(x, w, b)
-        assert x.format == "nhwc"
-        # test manually convert to NHWC, usually used in detection head
-        x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
-        gm.backward(x)
-        print("finish backward", x.format)
-        # backward grad has no format
-        np.testing.assert_equal(
-            w.grad.numpy(), np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
-        )
-        np.testing.assert_equal(
-            b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3))
-        )
+        rst = func(*inps)
+        gm.backward(rst)
+    expected_grads = [param.grad for param in model.parameters()]
+    inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
+    model = mge.amp.convert_module_format(model)
+    gm = GradManager().attach(model.parameters())
+    with gm:
+        rst = func(*inps)
+        gm.backward(rst)
+    actual_grads = [param.grad for param in model.parameters()]
+    for expected, actual in zip(expected_grads, actual_grads):
+        # print(param.grad)
+        np.testing.assert_equal(expected.numpy(), actual.numpy())
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_conv2d_dimshuffle(is_symbolic):
class Net(M.Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(2, 3, 1)
def forward(self, inp):
# test manually convert to NHWC, usually used in detection head
return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2)
inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
# x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
# w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
# b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
# grads = [
# np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
# np.array([12, 12, 12]).reshape((1, 1, 1, 3)),
# ]
_compare_backward([inp], Net(), is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_groupconv2d_bn(is_symbolic):
class Net(M.Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(2, 2, 1, groups=2)
self.bn = M.BatchNorm2d(2)
def forward(self, inp):
# test manually convert to NHWC, usually used in detection head
return self.bn(self.conv(inp))
inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
_compare_backward([inp], Net(), is_symbolic)
# def func(x, w, b, bn_w, bn_b):
# x = F.conv2d(x, w, b, groups=2)
# x = F.batch_norm(
# x,
# running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
# running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
# weight=bn_w,
# bias=bn_b,
# training=True,
# inplace=True,
# )
# return x
# data = np.arange(0, 24).reshape((1, 2, 3, 4))
# x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
# w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc")
# b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
# bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
# bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
# grads = [
# np.array([66, 210]).reshape((2, 1, 1, 1, 1)),
# np.array([12, 12]).reshape((1, 1, 1, 2)),
# np.array([12, 12]).reshape((1, 1, 1, 2)),
# np.array([12, 12]).reshape((1, 1, 1, 2)),
# ]
# _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic)
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
namespace mgb {
namespace imperative {
......@@ -17,7 +19,12 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
const std::string& scope) const {
std::vector<int32_t> pattern;
if (tensor.format() == FT::NHWC && target == FT::NCHW) {
-        pattern = {0, 3, 1, 2};
+        // FIXME(czh): temporary fast path for group conv 5D weight.
+        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+            pattern = {0, 1, 4, 2, 3};
+        } else {
+            pattern = {0, 3, 1, 2};
+        }
} else if (tensor.format() == FT::NCHW && target == FT::NHWC) {
pattern = {0, 2, 3, 1};
} else {
......@@ -65,12 +72,22 @@ inline ValueRefList FormatTransformation::wrap_outputs(
namespace {
ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
-    mgb_assert(shape.ndim == 4);
    auto out = ValueShape(shape);
-    out[3] = shape[2];
-    out[2] = shape[1];
-    out[1] = shape[3];
-    return out;
+    if (shape.ndim == 4) {
+        out[1] = shape[3];
+        out[2] = shape[1];
+        out[3] = shape[2];
+        return out;
+    } else if (shape.ndim == 5) {
+        out[2] = shape[4];
+        out[3] = shape[2];
+        out[4] = shape[3];
+        return out;
+    } else {
+        mgb_throw(
+                MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).",
+                shape.ndim);
+    }
}
using FormatRule = std::function<ValueRefList(
......@@ -278,10 +295,10 @@ ValueRefList setsubtensor_rule(
inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) {
FT format(FT::DEFAULT);
for (auto& inp : inputs) {
-        auto&& inp_ref = inp.as_ref(t.value_type());
-        if (inp_ref && inp_ref->format() != FT::DEFAULT) {
-            mgb_assert(format == FT::DEFAULT || inp_ref->format() == format);
-            format = inp_ref->format().type();
+        auto&& inp_format = inp.cast(t.value_type()).format();
+        if (inp_format != FT::DEFAULT) {
+            mgb_assert(format == FT::DEFAULT || inp_format == format);
+            format = inp_format.type();
}
}
return format;
......@@ -308,13 +325,6 @@ ValueRefList concat_rule(
format);
}
-ValueRefList elemwise_rule(
-        const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
-        const FormatTransformation& t) {
-    FT format = get_inputs_format(inputs, t);
-    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
-}
ValueRefList identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
// mgb_assert(inputs.size() == 1);
......@@ -336,24 +346,49 @@ ValueRefList batchnorm_rule(
return identity_rule_helper(op, inputs, t);
}
ValueRefList checknonfinite_rule(
const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
auto&& inputs_ = t.unwrap_inputs(inputs);
auto&& outputs_ = imperative::apply(op, inputs_);
return t.wrap_outputs(outputs_);
}
// clang-format off
-#define FOREACH_IDENTITY_OP(cb) \
-    cb(Copy)                    \
-    cb(FastpathCopy)            \
-    cb(TypeCvt)                 \
-    cb(Dropout)                 \
+#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
+    cb(Elemwise)                           \
+    cb(CompiledOp)                         \
+    cb(SubgraphOp)
+#define FOREACH_IDENTITY_OP(cb) \
+    cb(Copy)                    \
+    cb(FastpathCopy)            \
+    cb(TypeCvt)                 \
+    cb(Dropout)                 \
    cb(Identity)
-#define FOREACH_FORMAT_OP(cb) \
-    cb(AdaptivePooling)       \
-    cb(WarpAffine)            \
+#define FOREACH_FORMAT_OP(cb) \
+    cb(AdaptivePooling)       \
+    cb(WarpAffine)            \
    cb(Resize)
-#define FOREACH_FORMAT_POLICY_OP(cb)\
-    cb(Pooling) \
+#define FOREACH_FORMAT_POLICY_OP(cb) \
+    cb(Pooling)                      \
    cb(Convolution)
// clang-format on
// multi inputs op without params
#define CREATE_MULTI_INPS_NO_PARAM_OP_RULE(Op) \
ValueRefList Op##_rule( \
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \
FT format = get_inputs_format(inputs, t); \
return t.wrap_outputs( \
imperative::apply(_op, t.unwrap_inputs(inputs)), format); \
}
FOREACH_MULTI_INPS_NO_PARAM_OP(CREATE_MULTI_INPS_NO_PARAM_OP_RULE)
#undef CREATE_MULTI_INPS_NO_PARAM_OP_RULE
// identity op
#define CREATE_IDENTITY_OP_RULE(Op) \
ValueRefList Op##_rule( \
......@@ -409,8 +444,9 @@ struct FormatRuleRegistry {
register_format_rule(setsubtensor_rule<SetSubtensor>);
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(concat_rule);
-        register_format_rule(elemwise_rule);
        register_format_rule(batchnorm_rule);
+        register_format_rule(checknonfinite_rule);
+        FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
FOREACH_FORMAT_OP(REGISTER_OP_RULE)
FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE)
......@@ -455,27 +491,73 @@ ValueRefList FormatTransformation::apply_transformation(
return imperative::apply(op, unwrap_inputs(inputs));
}
} else if (op.is<GetFormat>()) {
-        bool is_formatted_tensor = inputs.item().is(m_value_type);
-        if (is_formatted_tensor) {
-            return {FormatValue::make(inputs[0].cast(m_value_type).format())};
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        if (inp_ref) {
+            return {FormatValue::make(inp_ref->format())};
        } else {
            mgb_log_warn(
-                    "Not FormattedTensorValue input for GetFormat op: %s",
-                    inputs[0].to_string().c_str());
+                    "Not FormattedTensorValue input for GetFormat op: %s, %s",
+                    op.to_string().c_str(), inputs[0].to_string().c_str());
return {FormatValue::make(FT::DEFAULT)};
}
} else if (op.is<Operator::IdentityLike>()) {
-        bool is_formatted_tensor = inputs.item().is(m_value_type);
-        if (is_formatted_tensor) {
-            auto&& format = inputs[0].cast(m_value_type).format();
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        if (inp_ref) {
+            auto&& format = inp_ref->format();
            return wrap_outputs(
                    imperative::apply(op, unwrap_inputs(inputs)), format.type());
        } else {
            mgb_log_warn(
-                    "Not FormattedTensorValue input for IdentityLike op: %s",
-                    inputs[0].to_string().c_str());
+                    "Not FormattedTensorValue input for IdentityLike op: %s, %s",
+                    op.to_string().c_str(), inputs[0].to_string().c_str());
return imperative::apply(op, inputs);
}
} else if (op.is<AttachGrad>()) {
auto&& inp_ref = inputs[0].as_ref(m_value_type);
if (inp_ref) {
auto format = inp_ref->format();
GenericFunction callback =
(GenericFunction&)inputs[1].cast<FunctionValue>();
GenericFunction new_callback =
[this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
auto wrapped_inputs = SmallVector<ValueRef>{
this->value_type().make(inputs_.item(), format.type())};
auto ret = callback(wrapped_inputs);
return ret;
};
auto&& outputs = imperative::apply(
op, inp_ref->value(), FunctionValue::make(new_callback));
return wrap_outputs(outputs, format.type());
} else {
mgb_log_warn(
"Not FormattedTensorValue input for AttachGrad op: %s, %s",
op.to_string().c_str(), inputs[0].to_string().c_str());
return imperative::apply(op, inputs);
}
} else if (auto* set_grad = op.as<SetGrad>()) {
size_t nr_inputs = set_grad->nr_inputs();
size_t nr_outputs = inputs.size() - nr_inputs;
Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
// run original apply.
// grads needn't to unwrap and wrap, which will be unwrapped in GradTrans
auto&& outputs = imperative::apply(op, unwrap_inputs(inputs));
// handle output's formats
auto wrapped_outputs = ValueRefList(nr_outputs);
for (size_t i = 0; i < nr_outputs; ++i) {
if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
wrapped_outputs[i] =
m_value_type.make(outputs[i], output_ref->format().type());
} else {
mgb_log_warn(
"Not FormattedTensorValue outputs for SetGrad op: %s, %s",
op.to_string().c_str(), inputs_[i].to_string().c_str());
wrapped_outputs[i] = m_value_type.make(outputs[i], FT::DEFAULT);
}
}
return wrapped_outputs;
} else {
return imperative::apply(op, unwrap_inputs(inputs));
}
......
......@@ -47,7 +47,10 @@ public:
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
-        mgb_assert(!value.is(m_value_type));
+        //mgb_assert(!value.is(m_value_type));
+        if (auto format_val = value.as_ref(m_value_type)) {
+            return format_val->value();
+        }
return value;
}
......
......@@ -377,6 +377,8 @@ public:
SetGrad(GenericFunction grad_fn, size_t nr_inputs)
: m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
std::shared_ptr<GradKey> key() const { return m_key; }
GenericFunction grad_fn() const { return m_grad_fn; }
size_t nr_inputs() const { return m_nr_inputs; }
......