提交 cd263765 编写于 作者: M Megvii Engine Team

style(imperative/amp): reformat code

GitOrigin-RevId: 6e5a6e1eaff88e0031ec4812be4f755d999244ab
上级 3892aa0b
......@@ -206,8 +206,8 @@ struct CheckNonFiniteOp {
return lhs | rhs;
}
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(
src_ctype** srcs, size_t* srcs_total_nr_elems, dst_ctype* dst,
size_t B, src_ctype scale)
src_ctype** srcs, size_t* srcs_total_nr_elems, dst_ctype* dst, size_t B,
src_ctype scale)
: INIT(wtype(0)),
srcs(srcs),
srcs_total_nr_elems(srcs_total_nr_elems),
......
......@@ -8,9 +8,9 @@
from copy import deepcopy
from .. import functional as F
from ..core import _config
from ..module import Module
from ..tensor import Tensor
from ..core import _config
def _is_nchw_format(param: Tensor):
......@@ -39,7 +39,7 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
else:
# use mge interface to maintain grad
x = F.transpose(x, pattern)
x.format="nhwc"
x.format = "nhwc"
return x
......
......@@ -134,7 +134,6 @@ def _compute_mode(mod, _compute_mode: str):
__compute_mode = _compute_mode
@property
def _bn_format(mod):
r"""Get or set batchnorm param layout format. The default option is None and will
......
......@@ -5,6 +5,7 @@ from typing import Iterable, Union
import numpy as np
from .. import _config
from .._imperative_rt import make_const
from .._imperative_rt.core2 import (
Const,
......@@ -24,7 +25,6 @@ from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._imperative_rt.ops import jit_supported
from .._wrap import as_device
from ..autodiff.grad import Function
from .. import _config
from ..ops import builtin
from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
......
......@@ -2,10 +2,10 @@
import os
from typing import Iterable, Union
from ..core import _config
from ..functional.inplace import _inplace_add_
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
from ..core import _config
class SGD(Optimizer):
......
......@@ -1024,9 +1024,9 @@ void init_tensor(py::module m) {
using namespace std::placeholders;
self.compiled = std::make_shared<CompiledTransformation>(
*self.trace_result, self.record_input_shapes);
self.compiled->set_value_comparator(
std::bind(&Trace::compare_value, this, _1, _2));
self.options_visitor(py::cast(&self.compiled->options()));
self.compiled->set_value_comparator(
std::bind(&Trace::compare_value, this, _1, _2));
self.options_visitor(py::cast(&self.compiled->options()));
try {
self.compiled->compile();
} catch (const std::exception& e) {
......
......@@ -320,7 +320,8 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) {
}
}
py::object device_obj = device2obj(device, true);
py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none(), py::none());
py::tuple tup =
py::make_tuple(val, dtype, device_obj, true, false, py::none(), py::none());
return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
......
......@@ -47,6 +47,10 @@ def test_grad_scaler(is_trace):
return loss
for data in [np.random.random((1, 2, 3, 4)), 1.0]:
for calc in [double_variables, single_variable, double_variables_with_same_grad]:
for calc in [
double_variables,
single_variable,
double_variables_with_same_grad,
]:
for idx in range(3):
f(idx, data, calc)
......@@ -260,7 +260,6 @@ void ChannelImpl::dispatch_default_cpu(
CompNode output_cn;
{
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD dispatch_default_cpu");
for (auto&& info : input_infos) {
auto input_cn = info->desc.comp_node;
if (!output_cn.valid()) {
......@@ -278,7 +277,6 @@ void ChannelImpl::dispatch_default_cpu(
input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
}
}
//mgb_log_warn("<<< MGB_LOCK_GUARD dispatch_default_cpu");
}
SmallVector<DeviceTensorND> output_tensornds;
......@@ -532,9 +530,7 @@ void ChannelImpl::sync() {
void ChannelImpl::sync_impl() {
m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD sync_impl");
check_worker_exc_unsafe();
//mgb_log_warn("<<< MGB_LOCK_GUARD sync_impl");
}
void ChannelImpl::close() {
......@@ -693,7 +689,6 @@ ChannelImpl::~ChannelImpl() {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
auto& state = get_worker_state();
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD produce_tensor");
m_dtr.update_used_time(dest);
MGB_RECORD_EVENT(
TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
......@@ -720,19 +715,16 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
m_dtr.insert_candidate(dest);
}
notify_tensor_unsafe(dest);
//mgb_log_warn("<<< MGB_LOCK_GUARD produce_tensor");
}
void ChannelImpl::release_tensor(TensorInfo* dest) {
MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD release_tensor");
dest->ptr.reset();
auto& state = get_worker_state();
if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.erase_candidate(dest);
}
//mgb_log_warn("<<< MGB_LOCK_GUARD release_tensor");
}
void ChannelImpl::regenerate(TensorInfo* dest) {
......@@ -1008,7 +1000,6 @@ bool ChannelImpl::check_available() {
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor");
mgb_assert(!m_waitee, "duplicate waitee");
m_waitee = info;
m_waitee_id = Profiler::next_id();
......@@ -1019,7 +1010,6 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
if (require_host && !host_available()) {
// avoid dead lock
lock.unlock();
//mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor unlock");
if (Profiler::is_profiling()) {
m_worker.add_task(
{Profiler::next_id(), GetValue{info},
......@@ -1031,21 +1021,18 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
});
}
lock.lock();
//mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor lock");
wait_host = true;
}
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return require_host ? host_available() : static_cast<bool>(info->ptr);
});
//mgb_log_warn("after cv wait");
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr;
if (wait_host) {
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
}
//mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor");
return info->ptr;
}
......@@ -1053,7 +1040,6 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
if (info == m_waitee) {
MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
m_cv.notify_all();
//mgb_log_warn("cv notify_all");
}
}
......@@ -1116,7 +1102,6 @@ void ChannelImpl::process_one_task(Command& icmd) {
using namespace ranges::views;
auto& state = get_worker_state();
auto& options = state.options;
//mgb_log_warn("process_one_task %s", to_string<Command>(icmd).c_str());
// TODO: remove std::visit for support osx 10.12
auto cmd_visitor = [&](const auto& cmd) {
using T = std::decay_t<decltype(cmd)>;
......@@ -1138,11 +1123,9 @@ void ChannelImpl::process_one_task(Command& icmd) {
for (auto& i : cmd.inputs) {
if (mgb_unlikely(i->invalid)) {
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD ApplyOp");
for (auto& i : cmd.outputs) {
i->invalid = true;
}
//mgb_log_warn("<<< MGB_LOCK_GUARD ApplyOp");
return;
}
}
......@@ -1227,10 +1210,8 @@ void ChannelImpl::process_one_task(Command& icmd) {
}
cmd.dest->ptr->fetch_value();
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD GetValue");
notify_tensor_unsafe(cmd.dest);
imperative_log_profile_end("GetValue");
//mgb_log_warn("<<< MGB_LOCK_GUARD GetValue");
} else if constexpr (std::is_same_v<T, Drop>) {
if (cmd.dest->invalid)
return;
......@@ -1290,7 +1271,6 @@ void ChannelImpl::process_one_task(Command& icmd) {
cmd_visitor(cmd);
} catch (...) {
MGB_LOCK_GUARD(m_mutex);
//mgb_log_warn(">>> MGB_LOCK_GUARD catch exception");
if constexpr (std::is_same_v<T, ApplyOp>) {
for (auto oup : cmd.outputs) {
oup->invalid = true;
......@@ -1303,7 +1283,6 @@ void ChannelImpl::process_one_task(Command& icmd) {
if (m_waitee) {
notify_tensor_unsafe(m_waitee);
}
//mgb_log_warn("<<< MGB_LOCK_GUARD catch exception");
}
},
icmd.data);
......
......@@ -380,7 +380,8 @@ ValueRefList identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
// mgb_assert(inputs.size() == 1);
if (auto& src = inputs[0].as_ref(t.value_type())) {
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src->format());
return t.wrap_outputs(
imperative::apply(op, t.unwrap_inputs(inputs)), src->format());
} else {
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
......
......@@ -49,7 +49,7 @@ public:
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
//mgb_assert(!value.is(m_value_type));
// mgb_assert(!value.is(m_value_type));
if (auto format_val = value.as_ref(m_value_type)) {
return format_val->value();
}
......@@ -69,8 +69,7 @@ public:
inline ValueRef wrap_output(
const ValueRef& output, Format format = Format::Type::DEFAULT) const;
inline ValueRefList wrap_outputs(
const ValueRefList& outputs,
Format format = Format::Type::DEFAULT) const;
const ValueRefList& outputs, Format format = Format::Type::DEFAULT) const;
TypedValueRef<FormattedTensorValue> as(
const FormattedTensorValue&, const Format::Type& target) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册