提交 d313f926 编写于 作者: M Megvii Engine Team

fix(imperative/amp): fix format transformation for symbol trans

GitOrigin-RevId: 96cc237c67e25c8cb1567eb08325db65adc1c57d
上级 261a5bce
...@@ -260,7 +260,6 @@ class GradManager: ...@@ -260,7 +260,6 @@ class GradManager:
push_scope("backward") push_scope("backward")
set_option("record_computing_path", 0) set_option("record_computing_path", 0)
_origin_auto_format = get_auto_format_convert() _origin_auto_format = get_auto_format_convert()
set_auto_format_convert(False)
from ..functional import ones_like from ..functional import ones_like
global backwarding_grad_manager global backwarding_grad_manager
...@@ -304,7 +303,6 @@ class GradManager: ...@@ -304,7 +303,6 @@ class GradManager:
self.release() self.release()
backwarding_grad_manager = cache backwarding_grad_manager = cache
set_option("record_computing_path", 1) set_option("record_computing_path", 1)
set_auto_format_convert(_origin_auto_format)
pop_scope("backward") pop_scope("backward")
def record(self): def record(self):
......
...@@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: ...@@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
return x return x
# set x's format to use FormatTransformation rule for Broadcast. # set x's format to use FormatTransformation rule for Broadcast.
return broadcast_to(x, inp.shape) rst = broadcast_to(x, inp.shape)
rst.format = inp.format
return rst
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
......
...@@ -26,7 +26,7 @@ public: ...@@ -26,7 +26,7 @@ public:
Eval, Eval,
}; };
std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments; std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments;
private: private:
template <Segment segment> template <Segment segment>
......
...@@ -12,6 +12,7 @@ import megengine.functional as F ...@@ -12,6 +12,7 @@ import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import Parameter, Tensor, amp from megengine import Parameter, Tensor, amp
from megengine.core._config import set_auto_format_convert from megengine.core._config import set_auto_format_convert
from megengine.core._trace_option import use_symbolic_shape
class MyModule(M.Module): class MyModule(M.Module):
...@@ -41,22 +42,25 @@ class MyModule(M.Module): ...@@ -41,22 +42,25 @@ class MyModule(M.Module):
def test_convert_module(is_inplace): def test_convert_module(is_inplace):
m = MyModule() m = MyModule()
expected_shape = { expected_shape = {
"i.bn.weight": (1, 1, 1, 4), "i.bn.weight": (1, 4, 1, 1),
"i.bn.bias": (1, 1, 1, 4), "i.bn.bias": (1, 4, 1, 1),
"i.bn.running_mean": (1, 1, 1, 4), "i.bn.running_mean": (1, 4, 1, 1),
"i.bn.running_var": (1, 1, 1, 4), "i.bn.running_var": (1, 4, 1, 1),
"conv.weight": (2, 2, 4, 4, 2), "conv.weight": (2, 2, 2, 4, 4),
"conv.bias": (1, 1, 1, 4), "conv.bias": (1, 4, 1, 1),
"bn.weight": (1, 1, 1, 4), "bn.weight": (1, 4, 1, 1),
"bn.bias": (1, 1, 1, 4), "bn.bias": (1, 4, 1, 1),
"bn.running_mean": (1, 1, 1, 4), "bn.running_mean": (1, 4, 1, 1),
"bn.running_var": (1, 1, 1, 4), "bn.running_var": (1, 4, 1, 1),
"param": (1, 1, 1, 3), "param": (1, 3, 1, 1),
"buff": (1, 1, 1, 3), "buff": (1, 3, 1, 1),
} }
m = amp.convert_module_format(m, is_inplace) m = amp.convert_module_format(m, is_inplace)
for name, param in m.named_tensors(): for name, param in m.named_tensors():
assert param.format == "nhwc" assert param.format == "nhwc"
set_auto_format_convert(False) if use_symbolic_shape():
assert param.shape == expected_shape[name], name np.testing.assert_array_equal(
set_auto_format_convert(True) param.shape.numpy(), expected_shape[name], name
)
else:
assert param.shape == expected_shape[name], name
...@@ -6,6 +6,7 @@ import megengine.functional as F ...@@ -6,6 +6,7 @@ import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import tensor from megengine import tensor
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
from megengine.core._trace_option import use_symbolic_shape
from megengine.jit import trace from megengine.jit import trace
...@@ -121,7 +122,10 @@ def test_repeat(is_symbolic): ...@@ -121,7 +122,10 @@ def test_repeat(is_symbolic):
@pytest.mark.parametrize("is_symbolic", [None]) @pytest.mark.parametrize("is_symbolic", [None])
def test_getshape(is_symbolic): def test_getshape(is_symbolic):
def func(x): def func(x):
return x.shape if use_symbolic_shape():
return x.shape.numpy()
else:
return x.shape
data = np.arange(0, 24).reshape((1, 2, 3, 4)) data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func, is_symbolic) _compare_nchw_nhwc(data, func, is_symbolic)
......
#include "megbrain/imperative/transformations/format.h" #include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/grad.h" #include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
...@@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs( ...@@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs(
} }
return wrapped_outputs; return wrapped_outputs;
} }
inline bool FormatTransformation::check_all_format_value(
const Span<ValueRef>& inputs) const {
for (size_t i = 0; i < inputs.size(); ++i) {
if (!inputs[i].as_ref(m_value_type)) {
return false;
}
}
return true;
}
namespace { namespace {
ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
...@@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format( ...@@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format(
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
auto&& inp = inputs[i].cast(t.value_type()); auto&& inp = inputs[i].cast(t.value_type());
if (inp.format() != dst_fmt && if (inp.format() != dst_fmt &&
inp.value().shape().cast<ShapeValue>().ndim == 4) { (inp.value().shape().cast<ShapeValue>().ndim == 4 ||
inp.value().shape().cast<ShapeValue>().ndim == 5)) {
unified_inputs[i] = t.to(inp, dst_fmt, scope); unified_inputs[i] = t.to(inp, dst_fmt, scope);
} else { } else {
unified_inputs[i] = inputs[i]; unified_inputs[i] = inputs[i];
...@@ -568,6 +581,10 @@ struct FormatRuleRegistry { ...@@ -568,6 +581,10 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation( ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) { const Operator& op, Span<ValueRef> inputs) {
if (auto* apply_op = op.as<ApplyOp>()) { if (auto* apply_op = op.as<ApplyOp>()) {
// bypass SymbolValue
if (!check_all_format_value(inputs)) {
return imperative::apply(op, unwrap_inputs(inputs));
}
// all inputs should be FormattedTensorValue // all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
if (iter != format_rules.end()) { if (iter != format_rules.end()) {
...@@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation( ...@@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation(
auto&& format = inp_ref->format(); auto&& format = inp_ref->format();
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
} else { } else {
mgb_log_warn(
"Not FormattedTensorValue input for IdentityLike op: %s, %s",
op.to_string().c_str(), inputs[0].to_string().c_str());
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }
} else if (op.is<AttachGrad>()) { } else if (op.is<AttachGrad>()) {
......
...@@ -70,6 +70,7 @@ public: ...@@ -70,6 +70,7 @@ public:
const ValueRef& output, Format format = Format::Type::DEFAULT) const; const ValueRef& output, Format format = Format::Type::DEFAULT) const;
inline ValueRefList wrap_outputs( inline ValueRefList wrap_outputs(
const ValueRefList& outputs, Format format = Format::Type::DEFAULT) const; const ValueRefList& outputs, Format format = Format::Type::DEFAULT) const;
inline bool check_all_format_value(const Span<ValueRef>& inputs) const;
TypedValueRef<FormattedTensorValue> as( TypedValueRef<FormattedTensorValue> as(
const FormattedTensorValue&, const Format::Type& target) const; const FormattedTensorValue&, const Format::Type& target) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册