Commit 06ba043a authored by Megvii Engine Team

fix(imperative): fix some subtensor errors

GitOrigin-RevId: bcc0307d67c66b4b7c9237775cdbe3b4360fdef5
Parent: a60ad267
@@ -435,7 +435,8 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
     """
     op = builtin.FillLike(value=value)
     (rst,) = apply(op, inp)
-    rst.format = inp.format
+    # rst.format = inp.format
+    # see jira:MGE-4505
     return rst
...
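For context, `full_like` builds a tensor with the same shape and dtype as its input, filled with a scalar; this hunk drops the explicit `rst.format = inp.format` assignment in Python (the last hunk of this commit adds `FillLike` to a format-op list on the C++ side instead). A minimal usage sketch, assuming the public `megengine.functional.full_like` API:

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.Tensor(np.arange(6, dtype="float32").reshape(2, 3))
y = F.full_like(x, 7.0)  # same shape and dtype as x, every element set to 7
np.testing.assert_array_equal(y.numpy(), np.full((2, 3), 7.0, dtype="float32"))
```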
@@ -1208,6 +1208,11 @@ py::object _fastpath_getitem_cpp(py::handle inp_hdl, py::tuple tuple_val) {
             ax += 1;
         } else if (PyBool_Check(t.ptr())) {
             expand_items.push_back(ax);
+            if (t.ptr() == Py_False) {
+                cpp_items.push_back({ax, true, true, true, false});
+                slice_items.push_back({0, 0, 1, INT_MAX});
+            }
             ax += 1;
         } else if (t.ptr() == Py_None) {
             expand_items.push_back(ax);
             ax += 1;
...
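The added branch makes a literal `False` index behave like an empty boolean mask: the new axis is still expanded, but it is sliced to length zero, while `True` keeps the existing size-1 expansion. A small NumPy sketch of the semantics the fastpath is matching:

```python
import numpy as np

x = np.random.rand(3, 2).astype("float32")
print(x[True, 0:1].shape)   # (1, 1, 2): True inserts a size-1 axis, then rows 0:1
print(x[False, 0:1].shape)  # (0, 1, 2): False inserts a size-0 axis, so no data
```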
@@ -342,6 +342,46 @@ def test_subtensor():
         np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy()
     )
 
+    x_np = np.random.rand(3, 2).astype("float32")
+    x = mge.Tensor(x_np)
+
+    with Grad() as grad:
+        grad.wrt(x, callback=save_to(x))
+        refs = {}
+
+        def f(x):
+            x = x * 1
+            y = x[True, 0:1]
+            refs["x"] = TensorWeakRef(x)
+            return y
+
+        y = f(x)
+        for _, r in refs.items():
+            assert r() is None
+        grad(y, F.ones_like(y))
+
+    np.testing.assert_equal(
+        np.array([[1, 1], [0, 0], [0, 0]], dtype=np.float32), x.grad.numpy()
+    )
+
+    with Grad() as grad:
+        grad.wrt(x, callback=save_to(x))
+        refs = {}
+
+        def f(x):
+            x = x * 1
+            y = x[False, 0:1]
+            refs["x"] = TensorWeakRef(x)
+            return y
+
+        y = f(x)
+        for _, r in refs.items():
+            assert r() is None
+        grad(y, F.ones_like(y))
+
+    np.testing.assert_equal(
+        np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32), x.grad.numpy()
+    )
+
 
 def test_IndexingMultiAxisVec():
     x_np = np.random.rand(3, 3).astype("float32")
...
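The expected gradients follow from the forward indexing: with `True` the slice `0:1` still selects row 0 of `x`, so backpropagating ones marks that row; with `False` nothing is selected and the gradient is zero everywhere. A small NumPy sketch of that expectation (the helper name is illustrative, not part of the test):

```python
import numpy as np

def expected_grad(x, use_true):
    # rows that flow into y = x[True/False, 0:1] receive gradient 1, the rest 0
    g = np.zeros_like(x)
    if use_true:       # True only adds a size-1 axis; row 0 is still selected
        g[0:1] = 1.0
    return g           # False selects a size-0 axis, so no row gets gradient

x = np.random.rand(3, 2).astype("float32")
print(expected_grad(x, True))   # [[1, 1], [0, 0], [0, 0]]
print(expected_grad(x, False))  # all zeros
```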
@@ -84,8 +84,8 @@ TensorLayout deduce_layout(
                 return 0;
             return v < 0 ? v + size_ax : v;
         };
-#define CHECK(cond) \
-    mgb_assert(cond, "index out of bound: layout=%s", src.to_string().c_str())
+    auto tostr = [](int v) -> std::string { return std::to_string(v); };
     for (int i = items.size() - 1; i >= 0; i--) {
         auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
@@ -99,16 +99,28 @@ TensorLayout deduce_layout(
                 slice_stop = mod_size(slice_stop, shape_axis);
                 slice_stop = std::min(slice_stop, shape_axis);
                 slice_start = std::min(slice_start, slice_stop);
-                CHECK(slice_start >= 0 && slice_stop >= slice_start &&
-                      slice_stop <= shape_axis);
+                mgb_assert(
+                        (slice_start >= 0 && slice_stop >= slice_start &&
+                         slice_stop <= shape_axis),
+                        "index out of bound: layout=%s; request begin=%s end=%s step=%s "
+                        "axis=%s",
+                        src.to_string().c_str(), tostr(slice_start).c_str(),
+                        tostr(slice_stop).c_str(), tostr(slice_step).c_str(),
+                        tostr(axis).c_str());
             } else {
                 slice_start = s_val == INT_MIN ? shape_axis - 1 : b_val;
                 slice_start = mod_size(slice_start, shape_axis);
                 slice_stop = e_val == INT_MAX ? -1 : mod_size(e_val, shape_axis);
                 slice_start = std::min(slice_start, std::max(shape_axis - 1, 0));
                 slice_stop = std::min(slice_stop, slice_start);
-                CHECK(slice_step < 0 && slice_start >= 0 && slice_stop <= slice_start &&
-                      slice_start < shape_axis && slice_stop >= -1);
+                mgb_assert(
+                        (slice_step < 0 && slice_start >= 0 && slice_stop <= slice_start &&
+                         slice_start < shape_axis && slice_stop >= -1),
+                        "index out of bound: layout=%s; request begin=%s end=%s step=%s "
+                        "axis=%s",
+                        src.to_string().c_str(), tostr(slice_start).c_str(),
+                        tostr(slice_stop).c_str(), tostr(slice_step).c_str(),
+                        tostr(axis).c_str());
             }
             int abs_step = std::abs(slice_step);
             if (axis < 0) {
@@ -205,7 +217,7 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
     return layout_checker;
 }
 
-OP_TRAIT_REG(Subtensor, Subtensor, opr::Subtensor)
+OP_TRAIT_REG(Subtensor, Subtensor)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .apply_on_physical_tensor(apply_on_physical_tensor)
...
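The two asserts above replace the old `CHECK` macro with `mgb_assert` calls that report the requested begin/end/step and axis alongside the layout. As a reading aid, here is a rough Python sketch of the clamping done in the positive-step branch before that assert (`slice_start` is presumably wrapped just above the visible lines; the `mod_size` helper name is taken from the hunk):

```python
def mod_size(v, size_ax):
    # wrap a negative index onto [0, size_ax), as the lambda in the hunk does
    if size_ax == 0:
        return 0
    return v + size_ax if v < 0 else v

def clamp_positive_step(start, stop, size_ax):
    stop = mod_size(stop, size_ax)
    stop = min(stop, size_ax)
    start = min(start, stop)
    # the condition the new mgb_assert spells out with begin/end/step/axis details
    assert 0 <= start <= stop <= size_ax, "index out of bound"
    return start, stop

print(clamp_positive_step(1, -1, 3))  # (1, 2): stop=-1 wraps to 2
```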
@@ -369,6 +369,27 @@ ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
     return imperative::apply(op, inputs);
 }
 
+ValueRefList group_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        SmallVector<DType> dtypes = get_value_dtypes(inputs);
+        ValueRefList converted(inputs.size());
+
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            mgb::DType target_dtype = DTypePromoteCfg::amp_high_prec_dtype;
+            if (dtypes[i] != target_dtype) {
+                converted[i] = imperative::apply(
+                        ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            } else {
+                converted[i] = inputs[i];
+            }
+        }
+
+        return imperative::apply(op, converted);
+    }
+
+    return imperative::apply(op, inputs);
+}
+
 ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
     SmallVector<DType> dtypes = get_value_dtypes(inputs);
     mgb::DType target_dtype = get_promoted_dtype(dtypes);
@@ -402,6 +423,7 @@ struct DTypePromoteRuleRegistry {
         register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
         register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
         register_dtype_promote_rule<LayerNorm>(layer_norm_rule);
+        register_dtype_promote_rule<GroupNorm>(group_norm_rule);
     }
 } register_helper;
...
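The new `group_norm_rule` converts, when AMP dtype autocast is enabled, every input that is not already the configured high-precision dtype via `TypeCvt` before GroupNorm runs. A minimal Python sketch of the same decision, using NumPy arrays as stand-ins and assuming float32 as the high-precision dtype (in the real code it comes from `DTypePromoteCfg::amp_high_prec_dtype`):

```python
import numpy as np

def group_norm_promote(inputs, amp_autocast_enabled, high_prec_dtype=np.float32):
    # cast any input whose dtype differs from the high-precision target
    if not amp_autocast_enabled:
        return list(inputs)
    return [x if x.dtype == high_prec_dtype else x.astype(high_prec_dtype) for x in inputs]

data = np.random.rand(2, 4, 8, 8).astype(np.float16)  # typical AMP activation dtype
weight = np.ones(4, dtype=np.float32)
bias = np.zeros(4, dtype=np.float32)
print([t.dtype for t in group_norm_promote([data, weight, bias], True)])
# [dtype('float32'), dtype('float32'), dtype('float32')]
```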
@@ -549,6 +549,7 @@ ValueRefList adaptive_pooling_rule(
             cb(FastpathCopy) \
             cb(TypeCvt) \
             cb(Dropout) \
+            cb(FillLike) \
             cb(Identity)
 
 #define FOREACH_FORMAT_OP(cb) \
...
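Adding `cb(FillLike)` to this op list is what lets the Python-side `rst.format = inp.format` assignment in `full_like` be dropped: the macro's name is cut off in this hunk, but it appears to enumerate ops whose output format simply follows the input. A hedged Python sketch of that idea (names are illustrative, not the engine's API):

```python
# ops assumed to pass the input's memory format straight through
FORMAT_TRANSPARENT_OPS = {"FastpathCopy", "TypeCvt", "Dropout", "FillLike", "Identity"}

def propagate_format(op_name, input_format):
    # e.g. FillLike on an NHWC tensor now yields an NHWC tensor without Python help
    if op_name in FORMAT_TRANSPARENT_OPS:
        return input_format
    raise NotImplementedError(f"no format rule sketched for {op_name}")

print(propagate_format("FillLike", "nhwc"))  # nhwc
```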