Commit d313f926 authored by Megvii Engine Team

fix(imperative/amp): fix format transformation for symbol trans

GitOrigin-RevId: 96cc237c67e25c8cb1567eb08325db65adc1c57d
Parent 261a5bce
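In the hunks below, full_like now copies the input's format onto its broadcast result, the transformation segment array grows from 8 to 9, FormatTransformation starts bypassing inputs that are not FormattedTensorValue (e.g. SymbolValue during symbolic tracing), and the tests switch to symbolic-shape-aware assertions. A minimal sketch (not part of the diff) of the symbolic-tracing scenario the fix targets, assuming Tensor.format is writable as the full_like change below suggests and that the traced op goes through FormatTransformation:

    import numpy as np
    from megengine import Tensor
    from megengine.jit import trace

    x = Tensor(np.random.randn(1, 2, 3, 4).astype("float32"))
    x.format = "nhwc"  # assumed writable, mirroring rst.format = inp.format below

    @trace(symbolic=True)
    def fwd(inp):
        # inside a symbolic trace, inputs may arrive as SymbolValue rather than
        # FormattedTensorValue; the new bypass keeps the apply from failing
        return inp + 1

    y = fwd(x)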
......@@ -260,7 +260,6 @@ class GradManager:
push_scope("backward")
set_option("record_computing_path", 0)
_origin_auto_format = get_auto_format_convert()
set_auto_format_convert(False)
from ..functional import ones_like
global backwarding_grad_manager
......@@ -304,7 +303,6 @@ class GradManager:
self.release()
backwarding_grad_manager = cache
set_option("record_computing_path", 1)
set_auto_format_convert(_origin_auto_format)
pop_scope("backward")
def record(self):
......
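For reference, the snapshot-and-restore pattern built from get_auto_format_convert/set_auto_format_convert that appears in this hunk looks like the following standalone sketch (assuming both helpers live in megengine.core._config, as set_auto_format_convert does in the test import further down):

    from megengine.core._config import (
        get_auto_format_convert,
        set_auto_format_convert,
    )

    _origin_auto_format = get_auto_format_convert()  # snapshot current setting
    set_auto_format_convert(False)  # disable automatic nchw/nhwc conversion
    try:
        ...  # format-sensitive work
    finally:
        set_auto_format_convert(_origin_auto_format)  # restore previous setting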
......@@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
return x
# set x's format to use FormatTransformation rule for Broadcast.
return broadcast_to(x, inp.shape)
rst = broadcast_to(x, inp.shape)
rst.format = inp.format
return rst
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
......
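A hedged usage sketch of the full_like change above (the input construction is made up; the point is that the broadcast result now carries the input's format):

    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    inp = Tensor(np.zeros((1, 2, 3, 4), dtype="float32"))
    inp.format = "nhwc"  # assumed writable, as in rst.format = inp.format above

    out = F.full_like(inp, 0.5)
    assert out.format == inp.format  # format is propagated to the broadcast result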
......@@ -26,7 +26,7 @@ public:
Eval,
};
std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments;
std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments;
private:
template <Segment segment>
......
......@@ -12,6 +12,7 @@ import megengine.functional as F
import megengine.module as M
from megengine import Parameter, Tensor, amp
from megengine.core._config import set_auto_format_convert
from megengine.core._trace_option import use_symbolic_shape
class MyModule(M.Module):
......@@ -41,22 +42,25 @@ class MyModule(M.Module):
def test_convert_module(is_inplace):
m = MyModule()
expected_shape = {
"i.bn.weight": (1, 1, 1, 4),
"i.bn.bias": (1, 1, 1, 4),
"i.bn.running_mean": (1, 1, 1, 4),
"i.bn.running_var": (1, 1, 1, 4),
"conv.weight": (2, 2, 4, 4, 2),
"conv.bias": (1, 1, 1, 4),
"bn.weight": (1, 1, 1, 4),
"bn.bias": (1, 1, 1, 4),
"bn.running_mean": (1, 1, 1, 4),
"bn.running_var": (1, 1, 1, 4),
"param": (1, 1, 1, 3),
"buff": (1, 1, 1, 3),
"i.bn.weight": (1, 4, 1, 1),
"i.bn.bias": (1, 4, 1, 1),
"i.bn.running_mean": (1, 4, 1, 1),
"i.bn.running_var": (1, 4, 1, 1),
"conv.weight": (2, 2, 2, 4, 4),
"conv.bias": (1, 4, 1, 1),
"bn.weight": (1, 4, 1, 1),
"bn.bias": (1, 4, 1, 1),
"bn.running_mean": (1, 4, 1, 1),
"bn.running_var": (1, 4, 1, 1),
"param": (1, 3, 1, 1),
"buff": (1, 3, 1, 1),
}
m = amp.convert_module_format(m, is_inplace)
for name, param in m.named_tensors():
assert param.format == "nhwc"
set_auto_format_convert(False)
assert param.shape == expected_shape[name], name
set_auto_format_convert(True)
if use_symbolic_shape():
np.testing.assert_array_equal(
param.shape.numpy(), expected_shape[name], name
)
else:
assert param.shape == expected_shape[name], name
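Why the branch above is needed (an illustrative sketch; the tensor here is made up): with symbolic shapes enabled, Tensor.shape is itself a Tensor, so it has to be materialized with .numpy() before comparing against a plain tuple.

    import numpy as np
    from megengine import Tensor
    from megengine.core._trace_option import use_symbolic_shape

    x = Tensor(np.zeros((2, 3), dtype="float32"))
    if use_symbolic_shape():
        np.testing.assert_array_equal(x.shape.numpy(), (2, 3))  # shape is a Tensor
    else:
        assert x.shape == (2, 3)  # shape is a plain tuple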
......@@ -6,6 +6,7 @@ import megengine.functional as F
import megengine.module as M
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.core._trace_option import use_symbolic_shape
from megengine.jit import trace
......@@ -121,7 +122,10 @@ def test_repeat(is_symbolic):
@pytest.mark.parametrize("is_symbolic", [None])
def test_getshape(is_symbolic):
def func(x):
return x.shape
if use_symbolic_shape():
return x.shape.numpy()
else:
return x.shape
data = np.arange(0, 24).reshape((1, 2, 3, 4))
_compare_nchw_nhwc(data, func, is_symbolic)
......
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
......@@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs(
}
return wrapped_outputs;
}
inline bool FormatTransformation::check_all_format_value(
const Span<ValueRef>& inputs) const {
for (size_t i = 0; i < inputs.size(); ++i) {
if (!inputs[i].as_ref(m_value_type)) {
return false;
}
}
return true;
}
namespace {
ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
......@@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format(
for (size_t i = 0; i < inputs.size(); ++i) {
auto&& inp = inputs[i].cast(t.value_type());
if (inp.format() != dst_fmt &&
inp.value().shape().cast<ShapeValue>().ndim == 4) {
(inp.value().shape().cast<ShapeValue>().ndim == 4 ||
inp.value().shape().cast<ShapeValue>().ndim == 5)) {
unified_inputs[i] = t.to(inp, dst_fmt, scope);
} else {
unified_inputs[i] = inputs[i];
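The ndim == 5 case added above covers grouped convolution weights, which are 5-dimensional in MegEngine (groups, out_channels/groups, in_channels/groups, kh, kw); this matches the 5-d conv.weight shape expected in the updated test. A small sketch of that assumption:

    import megengine.module as M

    m = M.Conv2d(4, 4, 3, groups=2)
    # grouped conv weights have shape (groups, out/groups, in/groups, kh, kw)
    assert m.weight.ndim == 5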
......@@ -568,6 +581,10 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* apply_op = op.as<ApplyOp>()) {
// bypass SymbolValue
if (!check_all_format_value(inputs)) {
return imperative::apply(op, unwrap_inputs(inputs));
}
// all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
if (iter != format_rules.end()) {
......@@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation(
auto&& format = inp_ref->format();
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
} else {
mgb_log_warn(
"Not FormattedTensorValue input for IdentityLike op: %s, %s",
op.to_string().c_str(), inputs[0].to_string().c_str());
return imperative::apply(op, inputs);
}
} else if (op.is<AttachGrad>()) {
......
......@@ -70,6 +70,7 @@ public:
const ValueRef& output, Format format = Format::Type::DEFAULT) const;
inline ValueRefList wrap_outputs(
const ValueRefList& outputs, Format format = Format::Type::DEFAULT) const;
inline bool check_all_format_value(const Span<ValueRef>& inputs) const;
TypedValueRef<FormattedTensorValue> as(
const FormattedTensorValue&, const Format::Type& target) const;
......