未验证 提交 ff4bdac3 编写于 作者: Z zyfncg 提交者: GitHub

fix a bug of slice by none index (#34877)

上级 fc6b4a50
...@@ -921,6 +921,29 @@ void BindImperative(py::module *m_ptr) { ...@@ -921,6 +921,29 @@ void BindImperative(py::module *m_ptr) {
axis -= len; axis -= len;
} }
// Deal with cases that there are more than one
// prefix none index, For example:
// [None, None, :, :, None]
// the none_axes int the return of ParseIndexingSlice is:
// [0, 0, 2 ]
// according to the interface of "unsqueeze2",
// we should convert it to:
// [0, 0, 4 ]
int prefix_zero_cnt = 0;
for (const auto &axis : none_axes) {
if (axis == 0) {
prefix_zero_cnt++;
} else {
break;
}
}
if (prefix_zero_cnt > 0) {
int none_axes_num = static_cast<int>(none_axes.size());
for (int i = prefix_zero_cnt; i < none_axes_num; ++i) {
none_axes[i] += prefix_zero_cnt;
}
}
imperative::NameVarBaseMap ins = {{"X", {out}}}; imperative::NameVarBaseMap ins = {{"X", {out}}};
framework::AttributeMap attrs = {{"axes", none_axes}}; framework::AttributeMap attrs = {{"axes", none_axes}};
auto new_out = std::shared_ptr<imperative::VarBase>( auto new_out = std::shared_ptr<imperative::VarBase>(
......
...@@ -711,6 +711,7 @@ class TestVarBase(unittest.TestCase): ...@@ -711,6 +711,7 @@ class TestVarBase(unittest.TestCase):
var_tensor[None, 2, None, 1].numpy(), var_tensor[None, 2, None, 1].numpy(),
var_tensor[None].numpy(), var_tensor[None].numpy(),
var_tensor[0, 0, None, 0, 0, None].numpy(), var_tensor[0, 0, None, 0, 0, None].numpy(),
var_tensor[None, None, 0, ..., None].numpy(),
var_tensor[0, 1:10:2, None, None, ...].numpy(), var_tensor[0, 1:10:2, None, None, ...].numpy(),
] ]
...@@ -724,11 +725,13 @@ class TestVarBase(unittest.TestCase): ...@@ -724,11 +725,13 @@ class TestVarBase(unittest.TestCase):
self.assertTrue(np.array_equal(var[7], np_value[None])) self.assertTrue(np.array_equal(var[7], np_value[None]))
self.assertTrue( self.assertTrue(
np.array_equal(var[8], np_value[0, 0, None, 0, 0, None])) np.array_equal(var[8], np_value[0, 0, None, 0, 0, None]))
self.assertTrue(
np.array_equal(var[9], np_value[None, None, 0, ..., None]))
# TODO(zyfncg) there is a bug of dimensions when slice step > 1 and # TODO(zyfncg) there is a bug of dimensions when slice step > 1 and
# indexs has int type # indexs has int type
# self.assertTrue( # self.assertTrue(
# np.array_equal(var[9], np_value[0, 1:10:2, None, None, ...])) # np.array_equal(var[10], np_value[0, 1:10:2, None, None, ...]))
def _test_for_var(self): def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32') np_value = np.random.random((30, 100, 100)).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册