Commit c28a875f authored by Megvii Engine Team

fix(imperative/amp): adapt new transformation

GitOrigin-RevId: 6edd577a70a8ea0ae00fbde0a6e4034273a30867
Parent fd41302c
@@ -50,8 +50,6 @@ class autocast:
self._origin_enabled = None
self._origin_high = None
self._origin_low = None
self._origin_compute_mode = None
self._origin_configs = None
def __enter__(self):
@@ -75,7 +73,7 @@ class autocast:
amp._set_amp_high_prec_dtype(self._origin_high)
amp._set_amp_low_prec_dtype(self._origin_low)
_config._reset_execution_config(*self._origin_compute_mode)
_config._reset_execution_config(*self._origin_configs)
def __call__(self, func):
@functools.wraps(func)
......
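The autocast hunks above consolidate the saved state into a single `_origin_configs` tuple, restored on exit via `_config._reset_execution_config(*self._origin_configs)`, and keep the decorator path in `__call__`. A minimal usage sketch of the two entry points, assuming the public `megengine.amp.autocast` API that the tests later in this diff exercise:

import megengine as mge
import megengine.functional as F

x = mge.tensor([[1.0, 3.0, 5.0]])
w = mge.tensor([[2.0], [4.0], [6.0]])

# context-manager form: ops inside run under the amp high/low precision dtypes
with mge.amp.autocast():
    y = F.matmul(x, w)

# decorator form, backed by __call__ + functools.wraps
@mge.amp.autocast()
def forward(a, b):
    return F.matmul(a, b)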
@@ -15,11 +15,14 @@ from ..core import _config
def _is_nchw_format(param: Tensor):
# TODO: use better condition
return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
return (param.ndim == 4 or param.ndim == 5) and param.format != "nhwc"
def convert_tensor_format(x: Tensor, inplace: bool = True):
"""Convert NCHW Tensor to NHWC Tensor."""
if not _is_nchw_format(x):
return x
if x.ndim == 4:
pattern = (0, 2, 3, 1)
elif x.ndim == 5:
@@ -29,8 +32,9 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
# TODO: use initialization from tensor after fixing format setting
if x.format != "nhwc":
if inplace:
# reset will destroy backward grad
# hostvalue should still be valid, so no d2h cost.
data = x.numpy().transpose(*pattern)
# reset will destroy existing backward grad
x[...] = Tensor(data, format="nhwc")
else:
# use mge interface to maintain grad
@@ -45,7 +49,5 @@ def convert_module_format(module: Module, inplace: bool = True):
module = deepcopy(module)
for name, param in module.named_tensors():
if _is_nchw_format(param):
# hostvalue should still be valid, so no d2h cost.
convert_tensor_format(param, inplace=True)
convert_tensor_format(param, inplace=True)
return module
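A minimal sketch of the conversion helpers touched above, assuming the `megengine.amp` exports used by the tests below: `convert_tensor_format` permutes 4-D NCHW data with `(0, 2, 3, 1)` and tags it `"nhwc"`, while `convert_module_format` applies that to every matching parameter, deep-copying first when `inplace=False`.

import numpy as np
import megengine as mge
import megengine.module as M

x = mge.tensor(np.ones((1, 2, 3, 4), dtype="float32"))    # NCHW layout
x_nhwc = mge.amp.convert_tensor_format(x, inplace=False)  # data permuted by (0, 2, 3, 1), format "nhwc"

m = M.Conv2d(2, 2, 1)
m_nhwc = mge.amp.convert_module_format(m, inplace=False)  # converts each NCHW parameter of the copy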
@@ -64,9 +64,7 @@ class Grad:
continue
grad.suppress()
print("before backward")
self._impl.backward(ys, dys)
print("after backward")
for grad in group:
if grad is self:
......
@@ -245,8 +245,6 @@ def conv2d(
sparse_type = "dense" if groups == 1 else "group"
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
with _config._override(auto_format_convert=False):
print(compute_mode, inp.shape, inp.format, weight.shape, weight.format)
op = builtin.Convolution(
stride_h=stride_h,
stride_w=stride_w,
......
@@ -320,7 +320,7 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) {
}
}
py::object device_obj = device2obj(device, true);
py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none(), py::none());
return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
......
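The `_Const` hunk grows the constructor tuple from six to seven elements. A plausible reading, given the rest of this commit, is that the trailing `py::none()` fills a newly added format slot in the tensor constructor; that is an assumption, since this diff alone does not show the constructor change. On the Python side the slot surfaces as the `format` keyword used by the tests below:

import numpy as np
import megengine as mge

data = np.ones((1, 2, 3, 4), dtype="float32")
b = mge.tensor(data.transpose(0, 2, 3, 1), format="nhwc")  # construct directly in NHWC
assert b.format == "nhwc"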
@@ -35,6 +35,7 @@ def test_basic():
b.format = "nhwc"
assert b.format == "nhwc"
def _compare_nchw_nhwc(data, func, is_symbolic=None):
x1 = tensor(data)
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
@@ -335,21 +336,42 @@ def _compare_backward(inps, model, is_symbolic=None):
gm = GradManager().attach(model.parameters())
with gm:
rst = func(*inps)
gm.backward(rst)
expected_grads = [param.grad for param in model.parameters()]
with mge.amp.autocast():
rst = func(*inps)
gm.backward(rst)
expected_grads = [param.grad.numpy() for param in gm.attached_tensors()]
for param in gm.attached_tensors():
param.grad = None
inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
model = mge.amp.convert_module_format(model)
gm = GradManager().attach(model.parameters())
with gm:
rst = func(*inps)
gm.backward(rst)
actual_grads = [param.grad for param in model.parameters()]
with mge.amp.autocast():
rst = func(*inps)
gm.backward(rst)
actual_grads = [param.grad.numpy() for param in gm.attached_tensors()]
for expected, actual in zip(expected_grads, actual_grads):
# print(param.grad)
np.testing.assert_equal(expected.numpy(), actual.numpy())
assert expected is not None
assert actual is not None
np.testing.assert_almost_equal(expected, actual, decimal=5)
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_basic(is_symbolic):
class Net(M.Module):
def __init__(self):
super().__init__()
self.w = mge.Parameter([[2.0], [4.0], [6.0]])
self.b = mge.Parameter(-1.0)
def forward(self, inp):
return F.matmul(inp, self.w) + self.b
inp = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
_compare_backward([inp], Net(), is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None])
@@ -379,14 +401,15 @@ def test_backward_groupconv2d_bn(is_symbolic):
class Net(M.Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(2, 2, 1, groups=2)
self.bn = M.BatchNorm2d(2)
self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
# self.bn = M.BatchNorm2d(2048)
def forward(self, inp):
# test manual conversion to NHWC, commonly used in detection heads
return self.bn(self.conv(inp))
return self.conv1(self.conv0(inp))
inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
_compare_backward([inp], Net(), is_symbolic)
# def func(x, w, b, bn_w, bn_b):
# x = F.conv2d(x, w, b, groups=2)
......
@@ -260,6 +260,7 @@ void ChannelImpl::dispatch_default_cpu(
CompNode output_cn;
{
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD dispatch_default_cpu");
for (auto&& info : input_infos) {
auto input_cn = info->desc.comp_node;
if (!output_cn.valid()) {
@@ -277,6 +278,7 @@ void ChannelImpl::dispatch_default_cpu(
input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
}
}
//mgb_log_warn("<<< MGB_LOCK_GUARD dispatch_default_cpu");
}
SmallVector<DeviceTensorND> output_tensornds;
@@ -530,7 +532,9 @@ void ChannelImpl::sync() {
void ChannelImpl::sync_impl() {
m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD sync_impl");
check_worker_exc_unsafe();
//mgb_log_warn("<<< MGB_LOCK_GUARD sync_impl");
}
void ChannelImpl::close() {
@@ -689,6 +693,7 @@ ChannelImpl::~ChannelImpl() {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
auto& state = get_worker_state();
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD produce_tensor");
m_dtr.update_used_time(dest);
MGB_RECORD_EVENT(
TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
@@ -715,16 +720,19 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
m_dtr.insert_candidate(dest);
}
notify_tensor_unsafe(dest);
//mgb_log_warn("<<< MGB_LOCK_GUARD produce_tensor");
}
void ChannelImpl::release_tensor(TensorInfo* dest) {
MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD release_tensor");
dest->ptr.reset();
auto& state = get_worker_state();
if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.erase_candidate(dest);
}
//mgb_log_warn("<<< MGB_LOCK_GUARD release_tensor");
}
void ChannelImpl::regenerate(TensorInfo* dest) {
@@ -1000,6 +1008,7 @@ bool ChannelImpl::check_available() {
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor");
mgb_assert(!m_waitee, "duplicate waitee");
m_waitee = info;
m_waitee_id = Profiler::next_id();
@@ -1010,6 +1019,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
if (require_host && !host_available()) {
// avoid dead lock
lock.unlock();
//mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor unlock");
if (Profiler::is_profiling()) {
m_worker.add_task(
{Profiler::next_id(), GetValue{info},
@@ -1021,18 +1031,21 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
});
}
lock.lock();
//mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor lock");
wait_host = true;
}
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return require_host ? host_available() : static_cast<bool>(info->ptr);
});
//mgb_log_warn("after cv wait");
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr;
if (wait_host) {
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
}
//mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor");
return info->ptr;
}
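The lock juggling in `wait_tensor` above follows a standard pattern: drop `m_mutex` before enqueueing the `GetValue` task (the "avoid dead lock" comment, since the worker thread also takes the mutex), re-acquire it, then block on the condition variable with a predicate until `notify_tensor_unsafe` fires. A Python analogue of the same pattern, as an illustrative sketch only:

import threading

cond = threading.Condition()  # plays the role of m_mutex + m_cv
ready = False

def worker():
    # stands in for the worker thread that eventually produces the tensor value
    global ready
    with cond:                # takes the same lock, hence the unlock below
        ready = True
        cond.notify_all()     # mirrors notify_tensor_unsafe -> m_cv.notify_all()

# start the producer while NOT holding the lock (wait_tensor's lock.unlock()
# before add_task), so the worker can acquire it without deadlocking
threading.Thread(target=worker).start()

with cond:                            # re-acquire (lock.lock())
    cond.wait_for(lambda: ready)      # mirrors m_cv.wait(lock, predicate)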
@@ -1040,6 +1053,7 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
if (info == m_waitee) {
MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
m_cv.notify_all();
//mgb_log_warn("cv notify_all");
}
}
@@ -1102,6 +1116,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
using namespace ranges::views;
auto& state = get_worker_state();
auto& options = state.options;
//mgb_log_warn("process_one_task %s", to_string<Command>(icmd).c_str());
// TODO: remove std::visit for support osx 10.12
auto cmd_visitor = [&](const auto& cmd) {
using T = std::decay_t<decltype(cmd)>;
@@ -1123,9 +1138,11 @@ void ChannelImpl::process_one_task(Command& icmd) {
for (auto& i : cmd.inputs) {
if (mgb_unlikely(i->invalid)) {
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD ApplyOp");
for (auto& i : cmd.outputs) {
i->invalid = true;
}
//mgb_log_warn("<<< MGB_LOCK_GUARD ApplyOp");
return;
}
}
@@ -1210,8 +1227,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
}
cmd.dest->ptr->fetch_value();
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD GetValue");
notify_tensor_unsafe(cmd.dest);
imperative_log_profile_end("GetValue");
//mgb_log_warn("<<< MGB_LOCK_GUARD GetValue");
} else if constexpr (std::is_same_v<T, Drop>) {
if (cmd.dest->invalid)
return;
@@ -1271,6 +1290,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
cmd_visitor(cmd);
} catch (...) {
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD catch exception");
if constexpr (std::is_same_v<T, ApplyOp>) {
for (auto oup : cmd.outputs) {
oup->invalid = true;
@@ -1283,6 +1303,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
if (m_waitee) {
notify_tensor_unsafe(m_waitee);
}
//mgb_log_warn("<<< MGB_LOCK_GUARD catch exception");
}
},
icmd.data);
......
@@ -33,9 +33,8 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
tensor.format().to_string().c_str(),
Format(target).to_string().c_str());
}
auto output = imperative::apply(
*Dimshuffle::make(pattern, scope),
SmallVector<ValueRef>{tensor.value()})[0];
auto output =
imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
return m_value_type.make(output, target);
}
@@ -90,6 +89,27 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
}
}
std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape) {
auto out = std::vector<int32_t>(shape);
if (shape.size() == 4) {
out[1] = shape[2];
out[2] = shape[3];
out[3] = shape[1];
return out;
} else if (shape.size() == 5) {
// GIOHW -> GIHWO
out[2] = shape[3];
out[3] = shape[4];
out[4] = shape[2];
return out;
} else {
mgb_throw(
MegBrainError,
"Unsupported shape ndim %u in convert NCHW shape to NHWC.",
shape.size());
}
}
using FormatRule = std::function<ValueRefList(
const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>;
static std::unordered_map<Typeinfo*, FormatRule> format_rules;
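`convert_nchw2nhwc_vector` above permutes a target-shape vector rather than tensor data: NCHW to NHWC for 4-D shapes, GIOHW to GIHWO for 5-D grouped-convolution weights. The same index mapping as a small Python sketch (illustrative of the logic, not the C++ API):

def convert_nchw2nhwc_vector(shape):
    out = list(shape)
    if len(shape) == 4:    # NCHW -> NHWC
        out[1], out[2], out[3] = shape[2], shape[3], shape[1]
    elif len(shape) == 5:  # GIOHW -> GIHWO
        out[2], out[3], out[4] = shape[3], shape[4], shape[2]
    else:
        raise ValueError("unsupported ndim %d" % len(shape))
    return out

assert convert_nchw2nhwc_vector([8, 3, 32, 32]) == [8, 32, 32, 3]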
@@ -156,22 +176,38 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
ValueRefList reshape_rule(
const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
mgb_assert(inputs.size() == 2);
mgb_assert(inputs.size() >= 1);
auto& src = inputs[0].cast(t.value_type());
if (auto_convert && src.format() == FT::NHWC) {
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply(
op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
return t.wrap_outputs(outputs);
if (inputs.size() == 1) {
if (op.shape.size() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
auto outputs = imperative::apply(
*Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(op, {nchw_src});
return t.wrap_outputs(outputs);
}
} else if (inputs.size() == 2) {
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply(
op,
SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
return t.wrap_outputs(outputs);
}
}
}
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
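`reshape_rule` now handles both call forms: the target shape baked into the op as an attribute (`inputs.size() == 1`, read from `op.shape`) and the shape passed as a second tensor input (`inputs.size() == 2`). In both, a 4-element NCHW target shape is permuted to NHWC so the reshape runs directly on the NHWC-stored data; other ranks fall back to converting the source to NCHW first. A NumPy sketch of why the permutation preserves the result for a channel-preserving reshape (illustrative only):

import numpy as np

x_nchw = np.arange(24).reshape(1, 2, 3, 4)  # logical NCHW data
x_nhwc = x_nchw.transpose(0, 2, 3, 1)       # same data stored as NHWC

tgt_nchw = (1, 2, 12, 1)                    # user-requested NCHW target shape
tgt_nhwc = (tgt_nchw[0], tgt_nchw[2], tgt_nchw[3], tgt_nchw[1])

lhs = x_nchw.reshape(tgt_nchw).transpose(0, 2, 3, 1)  # reshape in NCHW, re-layout after
rhs = x_nhwc.reshape(tgt_nhwc)                        # reshape the NHWC storage directly
assert np.array_equal(lhs, rhs)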
@@ -180,22 +216,38 @@ ValueRefList reshape_rule(
ValueRefList broadcast_rule(
const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
mgb_assert(inputs.size() == 2);
mgb_assert(inputs.size() >= 1);
auto& src = inputs[0].cast(t.value_type());
if (auto_convert && src.format() == FT::NHWC) {
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply(
op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
return t.wrap_outputs(outputs);
if (inputs.size() == 1) {
if (op.shape.size() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
auto outputs = imperative::apply(
*Broadcast::make(nhwc_shape), {t.unwrap_input(inputs[0])});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(op, {nchw_src});
return t.wrap_outputs(outputs);
}
} else if (inputs.size() == 2) {
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply(
op,
SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
return t.wrap_outputs(outputs);
}
}
}
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
@@ -240,8 +292,7 @@ ValueRefList subtensor_rule(
// only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) {
return {t.wrap_output(
imperative::apply(op, t.unwrap_inputs(inputs))[0],
src.format())};
imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
}
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
auto outputs = imperative::apply(
@@ -263,8 +314,7 @@ ValueRefList setsubtensor_rule(
// only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) {
return {t.wrap_output(
imperative::apply(op, t.unwrap_inputs(inputs))[0],
src.format())};
imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
}
// value has been broadcasted to src's fake NCHW shape.
auto& value = inputs[1].cast(t.value_type());
@@ -329,8 +379,7 @@ ValueRefList identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
// mgb_assert(inputs.size() == 1);
auto& src = inputs[0].cast(t.value_type());
return t.wrap_outputs(
imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
}
ValueRefList batchnorm_rule(
@@ -457,6 +506,7 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
//mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
if (auto* apply_op = op.as<ApplyOp>()) {
// all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
@@ -485,7 +535,7 @@ ValueRefList FormatTransformation::apply_transformation(
}
case GetAttr::Value: {
auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
return imperative::apply(op, SmallVector<ValueRef>{nchw_src});
return imperative::apply(op, {nchw_src});
}
default:
return imperative::apply(op, unwrap_inputs(inputs));
@@ -508,8 +558,7 @@ ValueRefList FormatTransformation::apply_transformation(
auto&& inp_ref = inputs[0].as_ref(m_value_type);
if (inp_ref) {
auto&& format = inp_ref->format();
return wrap_outputs(
imperative::apply(op, unwrap_inputs(inputs)), format);
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
} else {
mgb_log_warn(
"Not FormattedTensorValue input for IdentityLike op: %s, %s",
@@ -522,6 +571,7 @@ ValueRefList FormatTransformation::apply_transformation(
auto format = inp_ref->format();
GenericFunction callback =
(GenericFunction&)inputs[1].cast<FunctionValue>();
// make param grads as FormattedTensor
GenericFunction new_callback =
[this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
auto wrapped_inputs = SmallVector<ValueRef>{
@@ -531,6 +581,7 @@ ValueRefList FormatTransformation::apply_transformation(
};
auto&& outputs = imperative::apply(
op, inp_ref->value(), FunctionValue::make(new_callback));
// make params(GradValue) as FormattedTensor
return wrap_outputs(outputs, format);
} else {
mgb_log_warn(
@@ -539,6 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
return imperative::apply(op, inputs);
}
} else if (auto* set_grad = op.as<SetGrad>()) {
// make grads in Function backward as FormattedTensor
size_t nr_inputs = set_grad->nr_inputs();
size_t nr_outputs = inputs.size() - nr_inputs;
Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
......
@@ -377,8 +377,6 @@ public:
SetGrad(GenericFunction grad_fn, size_t nr_inputs)
: m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
std::shared_ptr<GradKey> key() const { return m_key; }
GenericFunction grad_fn() const { return m_grad_fn; }
size_t nr_inputs() const { return m_nr_inputs; }
......