提交 8563f514 编写于 作者: M Megvii Engine Team

fix(imperative): fix buildin reduce keepdim

GitOrigin-RevId: 38d90ab38adeaa139c420109a558b630125bcace
上级 96d90be1
...@@ -7,6 +7,8 @@ from utils import opr_test ...@@ -7,6 +7,8 @@ from utils import opr_test
import megengine.functional as F import megengine.functional as F
from megengine import jit, tensor from megengine import jit, tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
def common_test_reduce(opr, ref_opr): def common_test_reduce(opr, ref_opr):
...@@ -182,6 +184,21 @@ def test_sum_neg_axis(): ...@@ -182,6 +184,21 @@ def test_sum_neg_axis():
F.sum(tensor(data), axis=(-1, 1)) F.sum(tensor(data), axis=(-1, 1))
def test_builtin_reduce():
shape = (2, 3, 3, 2)
data = np.random.random(shape).astype(np.float32)
for axis in (-1, -2, 0, 1):
for keepdims in (True, False):
op = builtin.Reduce(mode="sum", axis=axis, keepdim=keepdims)
get = apply(op, tensor(data))[0]
def_op = builtin.Reduce(mode="sum", axis=axis)
def_get = apply(def_op, tensor(data))[0]
ref = np.sum(data, axis=axis, keepdims=keepdims)
np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
if keepdims == True:
np.testing.assert_allclose(def_get.numpy(), ref, rtol=1e-6)
def test_non_finite(): def test_non_finite():
shape = (32, 3, 32, 32) shape = (32, 3, 32, 32)
data = [] data = []
......
...@@ -222,7 +222,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -222,7 +222,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
dests[i].comp_node = inputs[i].comp_node; dests[i].comp_node = inputs[i].comp_node;
dests[i].layout = inputs[i].layout; dests[i].layout = inputs[i].layout;
if (not keepdim && dests[i].layout.ndim > 1) { if (!keepdim && dests[i].layout.ndim > 1) {
dests[i].layout.remove_axis_inplace(axis); dests[i].layout.remove_axis_inplace(axis);
} else { } else {
dests[i].layout.shape[axis] = 1; dests[i].layout.shape[axis] = 1;
......
...@@ -16,7 +16,7 @@ def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { ...@@ -16,7 +16,7 @@ def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>{ def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>{
let extraArguments = (ins let extraArguments = (ins
MgbBoolAttr:$keepdim MgbDefaultValuedAttr<MgbBoolAttr, "true">:$keepdim
); );
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册