Commit 261a5bce authored by Megvii Engine Team

feat(imperative/amp): add dimshuffle in set_format for nhwc

GitOrigin-RevId: 5ced9e1a31d78ea0628663ba9e3aa4942255713f
Parent c9e56f49
@@ -23,23 +23,17 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     if not _is_nchw_format(x):
         return x
-    if x.ndim == 4:
-        pattern = (0, 2, 3, 1)
-    elif x.ndim == 5:
-        pattern = (0, 1, 3, 4, 2)
-    else:
+    if x.ndim != 4 and x.ndim != 5:
         raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
     # TODO: use initialization from tensor after fixing format setting
-    if inplace:
-        # hostvalue should still be valid, so no d2h cost.
-        data = x.numpy()
-        # reset will destroy existed backward grad
-        x[...] = Tensor(data, format="nhwc")
-    else:
-        # use mge interface to maintain grad
-        x = F.transpose(x, pattern)
-        x.format = "nhwc"
+    if x.format != "nhwc":
+        # hostvalue should still be valid, so no d2h cost.
+        data = x.numpy()
+        if inplace:
+            # reset will destroy existed backward grad
+            x[...] = Tensor(data, format="nhwc")
+        else:
+            x = Tensor(data, format="nhwc")
     return x
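For reference, the dimshuffle patterns above map channels-first layouts to channels-last ones; with this commit the shuffle is applied inside the SetFormat op rather than in Python. A minimal NumPy sketch of what the two patterns do (the 5D layout interpretation is an assumption for illustration):

import numpy as np

# 4D activation/weight, NCHW -> NHWC via pattern (0, 2, 3, 1)
x4 = np.empty((2, 3, 8, 8))                              # (N, C, H, W)
assert x4.transpose(0, 2, 3, 1).shape == (2, 8, 8, 3)    # (N, H, W, C)

# 5D group-conv weight, pattern (0, 1, 3, 4, 2) moves the channel axis to the end
x5 = np.empty((4, 6, 3, 5, 5))                           # e.g. (groups, O, I, H, W)
assert x5.transpose(0, 1, 3, 4, 2).shape == (4, 6, 5, 5, 3)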
@@ -181,7 +181,6 @@ def _reset_execution_config(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    auto_format_convert=None,
 ):
     global _benchmark_kernel, _deterministic_kernel, __compute_mode
     orig_flags = (
@@ -189,7 +188,6 @@
         _deterministic_kernel,
         get_option("async_level"),
         __compute_mode,
-        get_auto_format_convert(),
     )
     if benchmark_kernel is not None:
         _benchmark_kernel = benchmark_kernel
@@ -199,8 +197,6 @@
         set_option("async_level", async_level)
     if compute_mode is not None:
         __compute_mode = compute_mode
-    if auto_format_convert is not None:
-        set_auto_format_convert(auto_format_convert)
     return orig_flags
@@ -211,7 +207,6 @@ def _override(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    auto_format_convert=None,
 ):
     r"""A context manager that users can opt in by attaching the decorator to set
     the config of the global variable.
@@ -227,7 +222,6 @@
             deterministic_kernel = Fasle,
             async_level=2,
             compute_mode="float32",
-            auto_format_convert=True,
         )
         def train():
     """
@@ -236,7 +230,6 @@
         deterministic_kernel=deterministic_kernel,
         async_level=async_level,
         compute_mode=compute_mode,
-        auto_format_convert=auto_format_convert,
     )
     try:
         yield
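With auto_format_convert removed from the override machinery, only the remaining execution options are routed through _override. A hedged usage sketch (module path as it appears in the diff; treat it as internal API):

import megengine.core._config as _config

# temporarily adjust the execution config; auto_format_convert is no longer
# an accepted keyword after this change
with _config._override(benchmark_kernel=True, async_level=0):
    pass  # run training or inference steps under the temporary config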
@@ -1206,9 +1206,9 @@ def batch_norm(
         if x is None:
             x = Const(value, inp.dtype, inp.device)
+            x.format = inp.format
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
-            result.format = inp.format
             return result
         else:
             assert x_ndim == 1
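The change above concerns how per-channel statistics are broadcast under a memory format: the broadcast shape is layout-dependent. A NumPy-only illustration with assumed shapes:

import numpy as np

C = 16
x_nchw = np.zeros((2, C, 8, 8), dtype=np.float32)    # channels-first activation
x_nhwc = np.zeros((2, 8, 8, C), dtype=np.float32)    # channels-last activation

# the same per-channel statistic needs a different broadcast shape per layout
stat_nchw = np.full((1, C, 1, 1), 2.0, dtype=np.float32)
stat_nhwc = np.full((1, 1, 1, C), 2.0, dtype=np.float32)

assert (x_nchw * stat_nchw).shape == x_nchw.shape
assert (x_nhwc * stat_nhwc).shape == x_nhwc.shape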
@@ -274,7 +274,6 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
        return x
    # set x's format to use FormatTransformation rule for Broadcast.
    x.format = inp.format
    return broadcast_to(x, inp.shape)
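A hedged sketch of the behavior this aims at: full_like builds its scalar with the input's format so that the following broadcast_to is handled by the Broadcast format rule (the NHWC tensor construction below is only an illustrative assumption):

import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor(np.zeros((2, 3, 4, 4), dtype="float32"), format="nhwc")
y = F.full_like(x, 1.0)   # expected to carry the same format as x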
@@ -91,14 +91,13 @@ class Optimizer(metaclass=ABCMeta):
         else:
             param_group["params"] = list(param_group["params"])
 
-        with _config._override(auto_format_convert=False):
-            for param in param_group["params"]:
-                if not isinstance(param, Parameter):
-                    raise TypeError(
-                        "optimizer can only optimize Parameters, but one of the params is "
-                        + str(type(param))
-                    )
-                param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+        for param in param_group["params"]:
+            if not isinstance(param, Parameter):
+                raise TypeError(
+                    "optimizer can only optimize Parameters, but one of the params is "
+                    + str(type(param))
+                )
+            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
 
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -121,8 +120,7 @@ class Optimizer(metaclass=ABCMeta):
     def _add_state(self, param, state_name, initializer=None):
         if initializer is None:
-            with _config._override(auto_format_convert=False):
-                initializer = np.zeros(param.shape, dtype=np.float32)
+            initializer = np.zeros(param.shape, dtype=np.float32)
         state_dict = self._state.setdefault(param, {})
         assert state_name not in state_dict
         state = Tensor(initializer, no_cache=True, format=param.format)
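The _reset pattern kept above rebuilds each parameter's storage while preserving its format tag, and it no longer needs the auto_format_convert=False guard. A minimal stand-alone sketch of that pattern (the parameter shape is an assumption):

import numpy as np
from megengine import Parameter, Tensor

p = Parameter(np.ones((8, 3, 3, 3), dtype=np.float32))
# re-create the underlying tensor out of the cache, keeping its memory format
p._reset(Tensor(p.numpy(), no_cache=True, format=p.format))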
@@ -10,7 +10,8 @@ import pytest
 import megengine.functional as F
 import megengine.module as M
-from megengine import Parameter, Tensor, amp, config
+from megengine import Parameter, Tensor, amp
+from megengine.core._config import set_auto_format_convert
 
 
 class MyModule(M.Module):
@@ -56,5 +57,6 @@ def test_convert_module(is_inplace):
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
-        with config._override(auto_format_convert=False):
-            assert param.shape == expected_shape[name], name
+        set_auto_format_convert(False)
+        assert param.shape == expected_shape[name], name
+        set_auto_format_convert(True)
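The test now flips the global flag directly instead of going through config._override. A hedged sketch of the same toggle pattern (the tensor and its reported shapes are illustrative assumptions):

import numpy as np
from megengine import Tensor
from megengine.core._config import set_auto_format_convert

t = Tensor(np.zeros((1, 2, 4, 4), dtype="float32"), format="nhwc")
set_auto_format_convert(False)    # shapes are reported as physically stored
raw_shape = t.shape
set_auto_format_convert(True)     # back to logical (NCHW-style) shapes
logical_shape = t.shape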
@@ -19,6 +19,9 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
         const std::string& scope) const {
     std::vector<int32_t> pattern;
     Format format = tensor.format();
+    if (format == target)
+        return as(tensor, target);
     if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
         if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
@@ -618,7 +621,7 @@ ValueRefList FormatTransformation::apply_transformation(
     } else if (auto* _op = op.as<SetFormat>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
-        return {m_value_type.make(inp_ref->value(), _op->format())};
+        return {to(*inp_ref, _op->format().type(), "")};
     } else if (op.is<Operator::IdentityLike>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
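Seen from Python, SetFormat now goes through FormatTransformation::to(), so assigning .format performs an actual dimshuffle instead of only re-tagging the value. A hedged sketch of the user-visible side (shape chosen for illustration):

import numpy as np
from megengine import Tensor

t = Tensor(np.arange(24, dtype="float32").reshape(1, 2, 3, 4))  # NCHW-ordered data
t.format = "nhwc"   # handled by the SetFormat op, which now applies the dimshuffle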