提交 c28a875f 编写于 作者: M Megvii Engine Team

fix(imperative/amp): adapt new transformation

GitOrigin-RevId: 6edd577a70a8ea0ae00fbde0a6e4034273a30867
上级 fd41302c
...@@ -50,8 +50,6 @@ class autocast: ...@@ -50,8 +50,6 @@ class autocast:
self._origin_enabled = None self._origin_enabled = None
self._origin_high = None self._origin_high = None
self._origin_low = None self._origin_low = None
self._origin_compute_mode = None
self._origin_configs = None self._origin_configs = None
def __enter__(self): def __enter__(self):
...@@ -75,7 +73,7 @@ class autocast: ...@@ -75,7 +73,7 @@ class autocast:
amp._set_amp_high_prec_dtype(self._origin_high) amp._set_amp_high_prec_dtype(self._origin_high)
amp._set_amp_low_prec_dtype(self._origin_low) amp._set_amp_low_prec_dtype(self._origin_low)
_config._reset_execution_config(*self._origin_compute_mode) _config._reset_execution_config(*self._origin_configs)
def __call__(self, func): def __call__(self, func):
@functools.wraps(func) @functools.wraps(func)
......
...@@ -15,11 +15,14 @@ from ..core import _config ...@@ -15,11 +15,14 @@ from ..core import _config
def _is_nchw_format(param: Tensor): def _is_nchw_format(param: Tensor):
# TODO: use better condition # TODO: use better condition
return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc" return (param.ndim == 4 or param.ndim == 5) and param.format != "nhwc"
def convert_tensor_format(x: Tensor, inplace: bool = True): def convert_tensor_format(x: Tensor, inplace: bool = True):
"""Convert NCHW Tensor to NHWC Tensor.""" """Convert NCHW Tensor to NHWC Tensor."""
if not _is_nchw_format(x):
return x
if x.ndim == 4: if x.ndim == 4:
pattern = (0, 2, 3, 1) pattern = (0, 2, 3, 1)
elif x.ndim == 5: elif x.ndim == 5:
...@@ -29,8 +32,9 @@ def convert_tensor_format(x: Tensor, inplace: bool = True): ...@@ -29,8 +32,9 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
# TODO: use initialization from tensor after fixing format setting # TODO: use initialization from tensor after fixing format setting
if x.format != "nhwc": if x.format != "nhwc":
if inplace: if inplace:
# reset will destroy backward grad # hostvalue should still be valid, so no d2h cost.
data = x.numpy().transpose(*pattern) data = x.numpy().transpose(*pattern)
# reset will destroy existed backward grad
x[...] = Tensor(data, format="nhwc") x[...] = Tensor(data, format="nhwc")
else: else:
# use mge interface to maintain grad # use mge interface to maintain grad
...@@ -45,7 +49,5 @@ def convert_module_format(module: Module, inplace: bool = True): ...@@ -45,7 +49,5 @@ def convert_module_format(module: Module, inplace: bool = True):
module = deepcopy(module) module = deepcopy(module)
for name, param in module.named_tensors(): for name, param in module.named_tensors():
if _is_nchw_format(param):
# hostvalue should still be valid, so no d2h cost.
convert_tensor_format(param, inplace=True) convert_tensor_format(param, inplace=True)
return module return module
...@@ -64,9 +64,7 @@ class Grad: ...@@ -64,9 +64,7 @@ class Grad:
continue continue
grad.suppress() grad.suppress()
print("before backward")
self._impl.backward(ys, dys) self._impl.backward(ys, dys)
print("after backward")
for grad in group: for grad in group:
if grad is self: if grad is self:
......
...@@ -245,8 +245,6 @@ def conv2d( ...@@ -245,8 +245,6 @@ def conv2d(
sparse_type = "dense" if groups == 1 else "group" sparse_type = "dense" if groups == 1 else "group"
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
with _config._override(auto_format_convert=False):
print(compute_mode, inp.shape, inp.format, weight.shape, weight.format)
op = builtin.Convolution( op = builtin.Convolution(
stride_h=stride_h, stride_h=stride_h,
stride_w=stride_w, stride_w=stride_w,
......
...@@ -320,7 +320,7 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) { ...@@ -320,7 +320,7 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) {
} }
} }
py::object device_obj = device2obj(device, true); py::object device_obj = device2obj(device, true);
py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none()); py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none(), py::none());
return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr); return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
} }
......
...@@ -35,6 +35,7 @@ def test_basic(): ...@@ -35,6 +35,7 @@ def test_basic():
b.format = "nhwc" b.format = "nhwc"
assert b.format == "nhwc" assert b.format == "nhwc"
def _compare_nchw_nhwc(data, func, is_symbolic=None): def _compare_nchw_nhwc(data, func, is_symbolic=None):
x1 = tensor(data) x1 = tensor(data)
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
...@@ -335,21 +336,42 @@ def _compare_backward(inps, model, is_symbolic=None): ...@@ -335,21 +336,42 @@ def _compare_backward(inps, model, is_symbolic=None):
gm = GradManager().attach(model.parameters()) gm = GradManager().attach(model.parameters())
with gm: with gm:
with mge.amp.autocast():
rst = func(*inps) rst = func(*inps)
gm.backward(rst) gm.backward(rst)
expected_grads = [param.grad for param in model.parameters()] expected_grads = [param.grad.numpy() for param in gm.attached_tensors()]
for param in gm.attached_tensors():
param.grad = None
inps = [mge.amp.convert_tensor_format(inp) for inp in inps] inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
model = mge.amp.convert_module_format(model) model = mge.amp.convert_module_format(model)
gm = GradManager().attach(model.parameters()) gm = GradManager().attach(model.parameters())
with gm: with gm:
with mge.amp.autocast():
rst = func(*inps) rst = func(*inps)
gm.backward(rst) gm.backward(rst)
actual_grads = [param.grad for param in model.parameters()] actual_grads = [param.grad.numpy() for param in gm.attached_tensors()]
for expected, actual in zip(expected_grads, actual_grads): for expected, actual in zip(expected_grads, actual_grads):
# print(param.grad) assert expected is not None
np.testing.assert_equal(expected.numpy(), actual.numpy()) assert actual is not None
np.testing.assert_almost_equal(expected, actual, decimal=5)
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_basic(is_symbolic):
class Net(M.Module):
def __init__(self):
super().__init__()
self.w = mge.Parameter([[2.0], [4.0], [6.0]])
self.b = mge.Parameter(-1.0)
def forward(self, inp):
return F.matmul(inp, self.w) + self.b
inp = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
_compare_backward([inp], Net(), is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None]) @pytest.mark.parametrize("is_symbolic", [None])
...@@ -379,14 +401,15 @@ def test_backward_groupconv2d_bn(is_symbolic): ...@@ -379,14 +401,15 @@ def test_backward_groupconv2d_bn(is_symbolic):
class Net(M.Module): class Net(M.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv = M.Conv2d(2, 2, 1, groups=2) self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
self.bn = M.BatchNorm2d(2) self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
# self.bn = M.BatchNorm2d(2048)
def forward(self, inp): def forward(self, inp):
# test manually convert to NHWC, usually used in detection head # test manually convert to NHWC, usually used in detection head
return self.bn(self.conv(inp)) return self.conv1(self.conv0(inp))
inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4))) inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
_compare_backward([inp], Net(), is_symbolic) _compare_backward([inp], Net(), is_symbolic)
# def func(x, w, b, bn_w, bn_b): # def func(x, w, b, bn_w, bn_b):
# x = F.conv2d(x, w, b, groups=2) # x = F.conv2d(x, w, b, groups=2)
......
...@@ -260,6 +260,7 @@ void ChannelImpl::dispatch_default_cpu( ...@@ -260,6 +260,7 @@ void ChannelImpl::dispatch_default_cpu(
CompNode output_cn; CompNode output_cn;
{ {
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD dispatch_default_cpu");
for (auto&& info : input_infos) { for (auto&& info : input_infos) {
auto input_cn = info->desc.comp_node; auto input_cn = info->desc.comp_node;
if (!output_cn.valid()) { if (!output_cn.valid()) {
...@@ -277,6 +278,7 @@ void ChannelImpl::dispatch_default_cpu( ...@@ -277,6 +278,7 @@ void ChannelImpl::dispatch_default_cpu(
input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu()); input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
} }
} }
//mgb_log_warn("<<< MGB_LOCK_GUARD dispatch_default_cpu");
} }
SmallVector<DeviceTensorND> output_tensornds; SmallVector<DeviceTensorND> output_tensornds;
...@@ -530,7 +532,9 @@ void ChannelImpl::sync() { ...@@ -530,7 +532,9 @@ void ChannelImpl::sync() {
void ChannelImpl::sync_impl() { void ChannelImpl::sync_impl() {
m_worker.wait_all_task_finish(); m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD sync_impl");
check_worker_exc_unsafe(); check_worker_exc_unsafe();
//mgb_log_warn("<<< MGB_LOCK_GUARD sync_impl");
} }
void ChannelImpl::close() { void ChannelImpl::close() {
...@@ -689,6 +693,7 @@ ChannelImpl::~ChannelImpl() { ...@@ -689,6 +693,7 @@ ChannelImpl::~ChannelImpl() {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
auto& state = get_worker_state(); auto& state = get_worker_state();
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD produce_tensor");
m_dtr.update_used_time(dest); m_dtr.update_used_time(dest);
MGB_RECORD_EVENT( MGB_RECORD_EVENT(
TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
...@@ -715,16 +720,19 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { ...@@ -715,16 +720,19 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
m_dtr.insert_candidate(dest); m_dtr.insert_candidate(dest);
} }
notify_tensor_unsafe(dest); notify_tensor_unsafe(dest);
//mgb_log_warn("<<< MGB_LOCK_GUARD produce_tensor");
} }
void ChannelImpl::release_tensor(TensorInfo* dest) { void ChannelImpl::release_tensor(TensorInfo* dest) {
MGB_RECORD_EVENT(TensorReleaseEvent, dest->id); MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD release_tensor");
dest->ptr.reset(); dest->ptr.reset();
auto& state = get_worker_state(); auto& state = get_worker_state();
if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.erase_candidate(dest); m_dtr.erase_candidate(dest);
} }
//mgb_log_warn("<<< MGB_LOCK_GUARD release_tensor");
} }
void ChannelImpl::regenerate(TensorInfo* dest) { void ChannelImpl::regenerate(TensorInfo* dest) {
...@@ -1000,6 +1008,7 @@ bool ChannelImpl::check_available() { ...@@ -1000,6 +1008,7 @@ bool ChannelImpl::check_available() {
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex); std::unique_lock<decltype(m_mutex)> lock(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor");
mgb_assert(!m_waitee, "duplicate waitee"); mgb_assert(!m_waitee, "duplicate waitee");
m_waitee = info; m_waitee = info;
m_waitee_id = Profiler::next_id(); m_waitee_id = Profiler::next_id();
...@@ -1010,6 +1019,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { ...@@ -1010,6 +1019,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
if (require_host && !host_available()) { if (require_host && !host_available()) {
// avoid dead lock // avoid dead lock
lock.unlock(); lock.unlock();
//mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor unlock");
if (Profiler::is_profiling()) { if (Profiler::is_profiling()) {
m_worker.add_task( m_worker.add_task(
{Profiler::next_id(), GetValue{info}, {Profiler::next_id(), GetValue{info},
...@@ -1021,18 +1031,21 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { ...@@ -1021,18 +1031,21 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
}); });
} }
lock.lock(); lock.lock();
//mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor lock");
wait_host = true; wait_host = true;
} }
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
return require_host ? host_available() : static_cast<bool>(info->ptr); return require_host ? host_available() : static_cast<bool>(info->ptr);
}); });
//mgb_log_warn("after cv wait");
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr; m_waitee = nullptr;
if (wait_host) { if (wait_host) {
auto err = info->ptr->comp_node().check_async_error(); auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what()); mgb_assert(!err, "%s", err->what());
} }
//mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor");
return info->ptr; return info->ptr;
} }
...@@ -1040,6 +1053,7 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) { ...@@ -1040,6 +1053,7 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
if (info == m_waitee) { if (info == m_waitee) {
MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id); MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
m_cv.notify_all(); m_cv.notify_all();
//mgb_log_warn("cv notify_all");
} }
} }
...@@ -1102,6 +1116,7 @@ void ChannelImpl::process_one_task(Command& icmd) { ...@@ -1102,6 +1116,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
using namespace ranges::views; using namespace ranges::views;
auto& state = get_worker_state(); auto& state = get_worker_state();
auto& options = state.options; auto& options = state.options;
//mgb_log_warn("process_one_task %s", to_string<Command>(icmd).c_str());
// TODO: remove std::visit for support osx 10.12 // TODO: remove std::visit for support osx 10.12
auto cmd_visitor = [&](const auto& cmd) { auto cmd_visitor = [&](const auto& cmd) {
using T = std::decay_t<decltype(cmd)>; using T = std::decay_t<decltype(cmd)>;
...@@ -1123,9 +1138,11 @@ void ChannelImpl::process_one_task(Command& icmd) { ...@@ -1123,9 +1138,11 @@ void ChannelImpl::process_one_task(Command& icmd) {
for (auto& i : cmd.inputs) { for (auto& i : cmd.inputs) {
if (mgb_unlikely(i->invalid)) { if (mgb_unlikely(i->invalid)) {
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD ApplyOp");
for (auto& i : cmd.outputs) { for (auto& i : cmd.outputs) {
i->invalid = true; i->invalid = true;
} }
//mgb_log_warn("<<< MGB_LOCK_GUARD ApplyOp");
return; return;
} }
} }
...@@ -1210,8 +1227,10 @@ void ChannelImpl::process_one_task(Command& icmd) { ...@@ -1210,8 +1227,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
} }
cmd.dest->ptr->fetch_value(); cmd.dest->ptr->fetch_value();
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD GetValue");
notify_tensor_unsafe(cmd.dest); notify_tensor_unsafe(cmd.dest);
imperative_log_profile_end("GetValue"); imperative_log_profile_end("GetValue");
//mgb_log_warn("<<< MGB_LOCK_GUARD GetValue");
} else if constexpr (std::is_same_v<T, Drop>) { } else if constexpr (std::is_same_v<T, Drop>) {
if (cmd.dest->invalid) if (cmd.dest->invalid)
return; return;
...@@ -1271,6 +1290,7 @@ void ChannelImpl::process_one_task(Command& icmd) { ...@@ -1271,6 +1290,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
cmd_visitor(cmd); cmd_visitor(cmd);
} catch (...) { } catch (...) {
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD catch exception");
if constexpr (std::is_same_v<T, ApplyOp>) { if constexpr (std::is_same_v<T, ApplyOp>) {
for (auto oup : cmd.outputs) { for (auto oup : cmd.outputs) {
oup->invalid = true; oup->invalid = true;
...@@ -1283,6 +1303,7 @@ void ChannelImpl::process_one_task(Command& icmd) { ...@@ -1283,6 +1303,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
if (m_waitee) { if (m_waitee) {
notify_tensor_unsafe(m_waitee); notify_tensor_unsafe(m_waitee);
} }
//mgb_log_warn("<<< MGB_LOCK_GUARD catch exception");
} }
}, },
icmd.data); icmd.data);
......
...@@ -33,9 +33,8 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to( ...@@ -33,9 +33,8 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
tensor.format().to_string().c_str(), tensor.format().to_string().c_str(),
Format(target).to_string().c_str()); Format(target).to_string().c_str());
} }
auto output = imperative::apply( auto output =
*Dimshuffle::make(pattern, scope), imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
SmallVector<ValueRef>{tensor.value()})[0];
return m_value_type.make(output, target); return m_value_type.make(output, target);
} }
...@@ -90,6 +89,27 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { ...@@ -90,6 +89,27 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
} }
} }
std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape) {
auto out = std::vector<int32_t>(shape);
if (shape.size() == 4) {
out[1] = shape[2];
out[2] = shape[3];
out[3] = shape[1];
return out;
} else if (shape.size() == 5) {
// GIOHW -> GIHWO
out[2] = shape[3];
out[3] = shape[4];
out[4] = shape[2];
return out;
} else {
mgb_throw(
MegBrainError,
"Unsupported shape ndim %u in convert NCHW shape to NHWC.",
shape.size());
}
}
using FormatRule = std::function<ValueRefList( using FormatRule = std::function<ValueRefList(
const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>; const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>;
static std::unordered_map<Typeinfo*, FormatRule> format_rules; static std::unordered_map<Typeinfo*, FormatRule> format_rules;
...@@ -156,15 +176,30 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) { ...@@ -156,15 +176,30 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
ValueRefList reshape_rule( ValueRefList reshape_rule(
const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert, const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) { const FormatTransformation& t) {
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() >= 1);
auto& src = inputs[0].cast(t.value_type()); auto& src = inputs[0].cast(t.value_type());
if (auto_convert && src.format() == FT::NHWC) { if (auto_convert && src.format() == FT::NHWC) {
if (inputs.size() == 1) {
if (op.shape.size() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
auto outputs = imperative::apply(
*Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(op, {nchw_src});
return t.wrap_outputs(outputs);
}
} else if (inputs.size() == 2) {
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd(); auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
if (shape.layout().total_nr_elems() == 4) { if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format // output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply( auto outputs = imperative::apply(
op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape}); op,
SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
return t.wrap_outputs(outputs, FT::NHWC); return t.wrap_outputs(outputs, FT::NHWC);
} else { } else {
// will not maintain src's format // will not maintain src's format
...@@ -174,21 +209,37 @@ ValueRefList reshape_rule( ...@@ -174,21 +209,37 @@ ValueRefList reshape_rule(
return t.wrap_outputs(outputs); return t.wrap_outputs(outputs);
} }
} }
}
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
} }
ValueRefList broadcast_rule( ValueRefList broadcast_rule(
const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert, const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) { const FormatTransformation& t) {
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() >= 1);
auto& src = inputs[0].cast(t.value_type()); auto& src = inputs[0].cast(t.value_type());
if (auto_convert && src.format() == FT::NHWC) { if (auto_convert && src.format() == FT::NHWC) {
if (inputs.size() == 1) {
if (op.shape.size() == 4) {
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
auto outputs = imperative::apply(
*Broadcast::make(nhwc_shape), {t.unwrap_input(inputs[0])});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
auto outputs = imperative::apply(op, {nchw_src});
return t.wrap_outputs(outputs);
}
} else if (inputs.size() == 2) {
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd(); auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
if (shape.layout().total_nr_elems() == 4) { if (shape.layout().total_nr_elems() == 4) {
// output is still NHWC format // output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
auto outputs = imperative::apply( auto outputs = imperative::apply(
op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape}); op,
SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
return t.wrap_outputs(outputs, FT::NHWC); return t.wrap_outputs(outputs, FT::NHWC);
} else { } else {
// will not maintain src's format // will not maintain src's format
...@@ -198,6 +249,7 @@ ValueRefList broadcast_rule( ...@@ -198,6 +249,7 @@ ValueRefList broadcast_rule(
return t.wrap_outputs(outputs); return t.wrap_outputs(outputs);
} }
} }
}
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
} }
...@@ -240,8 +292,7 @@ ValueRefList subtensor_rule( ...@@ -240,8 +292,7 @@ ValueRefList subtensor_rule(
// only support NHWC2NCHW convert, otherwise maintain src's format // only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) { if (!(auto_convert && src.format() == FT::NHWC)) {
return {t.wrap_output( return {t.wrap_output(
imperative::apply(op, t.unwrap_inputs(inputs))[0], imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
src.format())};
} }
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
auto outputs = imperative::apply( auto outputs = imperative::apply(
...@@ -263,8 +314,7 @@ ValueRefList setsubtensor_rule( ...@@ -263,8 +314,7 @@ ValueRefList setsubtensor_rule(
// only support NHWC2NCHW convert, otherwise maintain src's format // only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) { if (!(auto_convert && src.format() == FT::NHWC)) {
return {t.wrap_output( return {t.wrap_output(
imperative::apply(op, t.unwrap_inputs(inputs))[0], imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
src.format())};
} }
// value has been broadcasted to src's fake NCHW shape. // value has been broadcasted to src's fake NCHW shape.
auto& value = inputs[1].cast(t.value_type()); auto& value = inputs[1].cast(t.value_type());
...@@ -329,8 +379,7 @@ ValueRefList identity_rule_helper( ...@@ -329,8 +379,7 @@ ValueRefList identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) { const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
// mgb_assert(inputs.size() == 1); // mgb_assert(inputs.size() == 1);
auto& src = inputs[0].cast(t.value_type()); auto& src = inputs[0].cast(t.value_type());
return t.wrap_outputs( return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
} }
ValueRefList batchnorm_rule( ValueRefList batchnorm_rule(
...@@ -457,6 +506,7 @@ struct FormatRuleRegistry { ...@@ -457,6 +506,7 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation( ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) { const Operator& op, Span<ValueRef> inputs) {
//mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
if (auto* apply_op = op.as<ApplyOp>()) { if (auto* apply_op = op.as<ApplyOp>()) {
// all inputs should be FormattedTensorValue // all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
...@@ -485,7 +535,7 @@ ValueRefList FormatTransformation::apply_transformation( ...@@ -485,7 +535,7 @@ ValueRefList FormatTransformation::apply_transformation(
} }
case GetAttr::Value: { case GetAttr::Value: {
auto nchw_src = unwrap_input(to(src, FT::NCHW, "")); auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
return imperative::apply(op, SmallVector<ValueRef>{nchw_src}); return imperative::apply(op, {nchw_src});
} }
default: default:
return imperative::apply(op, unwrap_inputs(inputs)); return imperative::apply(op, unwrap_inputs(inputs));
...@@ -508,8 +558,7 @@ ValueRefList FormatTransformation::apply_transformation( ...@@ -508,8 +558,7 @@ ValueRefList FormatTransformation::apply_transformation(
auto&& inp_ref = inputs[0].as_ref(m_value_type); auto&& inp_ref = inputs[0].as_ref(m_value_type);
if (inp_ref) { if (inp_ref) {
auto&& format = inp_ref->format(); auto&& format = inp_ref->format();
return wrap_outputs( return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
imperative::apply(op, unwrap_inputs(inputs)), format);
} else { } else {
mgb_log_warn( mgb_log_warn(
"Not FormattedTensorValue input for IdentityLike op: %s, %s", "Not FormattedTensorValue input for IdentityLike op: %s, %s",
...@@ -522,6 +571,7 @@ ValueRefList FormatTransformation::apply_transformation( ...@@ -522,6 +571,7 @@ ValueRefList FormatTransformation::apply_transformation(
auto format = inp_ref->format(); auto format = inp_ref->format();
GenericFunction callback = GenericFunction callback =
(GenericFunction&)inputs[1].cast<FunctionValue>(); (GenericFunction&)inputs[1].cast<FunctionValue>();
// make param grads as FormattedTensor
GenericFunction new_callback = GenericFunction new_callback =
[this, callback, format](Span<ValueRef> inputs_) -> ValueRefList { [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
auto wrapped_inputs = SmallVector<ValueRef>{ auto wrapped_inputs = SmallVector<ValueRef>{
...@@ -531,6 +581,7 @@ ValueRefList FormatTransformation::apply_transformation( ...@@ -531,6 +581,7 @@ ValueRefList FormatTransformation::apply_transformation(
}; };
auto&& outputs = imperative::apply( auto&& outputs = imperative::apply(
op, inp_ref->value(), FunctionValue::make(new_callback)); op, inp_ref->value(), FunctionValue::make(new_callback));
// make params(GradValue) as FormattedTensor
return wrap_outputs(outputs, format); return wrap_outputs(outputs, format);
} else { } else {
mgb_log_warn( mgb_log_warn(
...@@ -539,6 +590,7 @@ ValueRefList FormatTransformation::apply_transformation( ...@@ -539,6 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }
} else if (auto* set_grad = op.as<SetGrad>()) { } else if (auto* set_grad = op.as<SetGrad>()) {
// make grads in Function backward as FormattedTensor
size_t nr_inputs = set_grad->nr_inputs(); size_t nr_inputs = set_grad->nr_inputs();
size_t nr_outputs = inputs.size() - nr_inputs; size_t nr_outputs = inputs.size() - nr_inputs;
Span<ValueRef> inputs_ = {inputs.data(), nr_inputs}; Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
......
...@@ -377,8 +377,6 @@ public: ...@@ -377,8 +377,6 @@ public:
SetGrad(GenericFunction grad_fn, size_t nr_inputs) SetGrad(GenericFunction grad_fn, size_t nr_inputs)
: m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
std::shared_ptr<GradKey> key() const { return m_key; }
GenericFunction grad_fn() const { return m_grad_fn; } GenericFunction grad_fn() const { return m_grad_fn; }
size_t nr_inputs() const { return m_nr_inputs; } size_t nr_inputs() const { return m_nr_inputs; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册