提交 94dba16f 编写于 作者: M Megvii Engine Team

perf(mge/imperative): misc optimizations

GitOrigin-RevId: bbe7a10b007e8b6d8a66dd64ff83ac1df4b9f8d2
上级 9f139562
......@@ -22,7 +22,14 @@ class Device:
else:
self._cn = CompNode(device)
self.logical_name = self._cn.logical_name
self._logical_name = None
@property
def logical_name(self):
if self._logical_name:
return self._logical_name
self._logical_name = self._cn.logical_name
return self._logical_name
def to_c(self):
return self._cn
......@@ -39,7 +46,7 @@ class Device:
def __eq__(self, rhs):
if not isinstance(rhs, Device):
rhs = Device(rhs)
return str(self._cn) == str(rhs._cn)
return self._cn == rhs._cn
def device(obj):
......
......@@ -28,6 +28,7 @@ from ..ops.builtin import (
from ..ops.special import Const
from ..tensor.core import apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor
from ..tensor.tensor_wrapper import TensorWrapper
_reduce_sum_param = Reduce(mode="SUM").to_c().param[0]
......@@ -103,8 +104,8 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
def get_shape(x):
(s,) = apply(GetVarShape(), x)
return s
(s,) = apply(GetVarShape(), x._data)
return Tensor(s)
# override for Elemwise.add
......
......@@ -387,16 +387,19 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
if not manager._enabled:
return
opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
# register backward method
# tuple of backward functions corresponding to dy / dx_i
# None means y is not a function of x_i
opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
op, ctx.inputs, ctx.outputs, input_requires_grad
)
assert len(ctx.outputs) == len(output_need_grad)
if not any(output_need_grad):
return
opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
opnode.backward = backward
assert len(outputs) == len(output_need_grad)
outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
opnode.backward_allow_noinput = check_backward_allow_noinput(op)
......
......@@ -55,6 +55,8 @@ class Tensor(TensorBase):
class ApplyContext:
__slots__ = ("inputs", "outputs", "key")
def __init__(self):
self.inputs = None
self.outputs = None
......@@ -81,7 +83,7 @@ def get_context():
@apply.register()
def tensor_apply(op: OpBase, *args: Tensor):
data = tuple(i._data if isinstance(i, Tensor) else i for i in args)
data = tuple(i._data for i in args)
# type(Tensor._data) is RawTensor
# dispached to apply.add@RawTensor.py if passed Tensor args
outputs = apply(op, *data)
......@@ -90,7 +92,7 @@ def tensor_apply(op: OpBase, *args: Tensor):
with push_context() as ctx:
ctx.inputs = args
ctx.outputs = ret
for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
for k in set().union(*(i._extra_data for i in args)):
ctx.key = k
data = tuple(
i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args
......
......@@ -229,7 +229,7 @@ def mean(
[3.5]
"""
return inp.astype("float32").mean(axis=axis, keepdims=keepdims)
return inp.mean(axis=axis, keepdims=keepdims)
def var(
......
......@@ -35,15 +35,14 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
tshape = (1, self.num_features, 1, 1)
if self.affine:
self.weight = Parameter(np.ones(num_features, dtype=np.float32))
self.bias = Parameter(np.zeros(num_features, dtype=np.float32))
self.weight = Parameter(np.ones(tshape, dtype=np.float32))
self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
else:
self.weight = None
self.bias = None
tshape = (1, self.num_features, 1, 1)
if self.track_running_stats:
self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
......@@ -86,10 +85,8 @@ class _BatchNorm(Module):
inp = inp.reshape(new_shape)
if self.freeze and self.training and self._track_running_stats_saved:
scale = self.weight.reshape(1, -1, 1, 1) * (
self.running_var + self.eps
) ** (-0.5)
bias = self.bias.reshape(1, -1, 1, 1) - self.running_mean * scale
scale = self.weight * (self.running_var + self.eps) ** (-0.5)
bias = self.bias - self.running_mean * scale
return inp * scale.detach() + bias.detach()
if self.training and self.track_running_stats:
......@@ -276,7 +273,7 @@ class BatchNorm2d(_BatchNorm):
m = M.BatchNorm2d(4)
inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
oup = m(inp)
print(m.weight.numpy(), m.bias.numpy())
print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
# Without L`e`arnable Parameters
m = M.BatchNorm2d(4, affine=False)
oup = m(inp)
......
......@@ -55,6 +55,14 @@ def _is_module(obj):
return isinstance(obj, Module)
def _get_XNorm_typeclass():
from .batchnorm import _BatchNorm
XNorm_types = []
XNorm_types.append(_BatchNorm)
return tuple(XNorm_types)
class Module(metaclass=ABCMeta):
"""
Base Module class.
......@@ -393,6 +401,18 @@ class Module(metaclass=ABCMeta):
return offset
def state_dict(self, rst=None, prefix="", keep_var=False):
_rst = self._state_dict(rst=rst, prefix=prefix, keep_var=keep_var)
rst = OrderedDict()
XNorm_typeclass = _get_XNorm_typeclass()
for (module_type, k), v in _rst.items():
# for performance reasons, parameters in XNorm (e.g., BatchNorm2d) are 4-dim tensors,
# however they will be reshaped to 1-dim tensors before returned by `statr_dict()`
if issubclass(module_type, XNorm_typeclass):
v = v.reshape(-1)
rst[k] = v
return rst
def _state_dict(self, rst=None, prefix="", keep_var=False):
r"""
Returns a dictionary containing whole states of the module.
"""
......@@ -400,15 +420,16 @@ class Module(metaclass=ABCMeta):
def is_state(obj):
return _is_parameter(obj) or _is_buffer(obj)
module_type = self.__class__
if rst is None:
rst = OrderedDict()
for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
assert prefix + k not in rst, "duplicated state: {}".format(k)
if keep_var:
rst[prefix + k] = v
rst[(module_type, prefix + k)] = v
else:
rst[prefix + k] = v.numpy()
rst[(module_type, prefix + k)] = v.numpy()
for k, submodule in self._flatten(
recursive=False,
......@@ -507,13 +528,14 @@ class Module(metaclass=ABCMeta):
Advance state_dict load through callable ``closure`` whose signature is
``closure(key: str, var: Tensor) -> Union[np.ndarry, None]``
"""
XNorm_typeclass = _get_XNorm_typeclass()
assert callable(closure), "closure must be a function"
loaded = []
skipped = []
local_state_dict = self.state_dict(keep_var=True)
for k, var in local_state_dict.items():
local_state_dict = self._state_dict(keep_var=True)
for (module_type, k), var in local_state_dict.items():
to_be_load = closure(k, var)
if to_be_load is None:
skipped.append(k)
......@@ -523,11 +545,27 @@ class Module(metaclass=ABCMeta):
), "closure should return a `np.ndarray`, now `{}` get {}".format(
k, to_be_load
)
assert make_shape_tuple(var.shape) == make_shape_tuple(
to_be_load.shape
), "param `{}` shape mismatch, should be {}, get {}".format(
k, var.shape, to_be_load.shape
)
var_shape = make_shape_tuple(var.shape)
to_be_load_shape = make_shape_tuple(to_be_load.shape)
if var_shape != to_be_load_shape:
# weight and bias in BatchNorm1d, BatchNorm2d and SyncBatchNorm are 1-dim tensors in v1.0, and
# since v1.1 they are 4-dim tensors. The following special rule for these modules preserves the
# backward compatibility.
if issubclass(module_type, XNorm_typeclass):
if np.prod(var_shape) == np.prod(to_be_load_shape):
to_be_load = to_be_load.reshape(var_shape)
else:
raise ValueError(
"param `{}` size mismatch, should be {}, get {}".format(
k, np.prod(var_shape), np.prod(to_be_load_shape)
)
)
else:
raise ValueError(
"param `{}` shape mismatch, should be {}, get {}".format(
k, var_shape, to_be_load_shape
)
)
var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
loaded.append(k)
......
......@@ -193,7 +193,11 @@ def run_train(
net.state_dict().items(), checkpoint["net_updated"].items()
):
assert param[0] == param_ref[0]
np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
if "bn" in param[0]:
ref = param_ref[1].reshape(param[1].shape)
np.testing.assert_allclose(param[1], ref, atol=max_err)
else:
np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
def run_eval(
......
......@@ -188,7 +188,11 @@ def run_test(
net.state_dict().items(), checkpoint["net_updated"].items()
):
assert param[0] == param_ref[0]
np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
if "bn" in param[0]:
ref = param_ref[1].reshape(param[1].shape)
np.testing.assert_allclose(param[1], ref, atol=max_err)
else:
np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
procs = []
for rank in range(p_num):
......
......@@ -107,7 +107,7 @@ private:
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
//! level 0: both sync.
int m_async_level = 1;
int m_async_level = 2;
};
} // namespace mgb::imperative::interpreter::intl
......@@ -94,7 +94,7 @@ private:
cg::OperatorNodeBase* m_cur_opr = nullptr;
std::unique_ptr<ProxyGraphImpl> m_graph;
size_t m_max_op_cnt = 1000;
size_t m_max_op_cnt = 100;
std::unique_ptr<ExecEnv> m_env;
std::unique_ptr<StaticInferManager> m_static_infer_manager;
std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;
......
......@@ -120,12 +120,12 @@ make_backward_graph(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad) {
auto&& graph = ProxyGraph::get_default_graph();
auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
auto&& iter = backward_graph_cache.find(hash_key);
if (iter != backward_graph_cache.end()) {
return iter->second;
}
auto&& graph = ProxyGraph::get_default_graph();
auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
backward_graph_cache.emplace(hash_key, res);
return res;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册