未验证 提交 d12c3636 编写于 作者: T TTerror 提交者: GitHub

fix gather_nd, *test=kunlun (#39283)

上级 9ba3f429
......@@ -47,8 +47,12 @@ class GatherNdXPUKernel : public framework::OpKernel<T> {
auto x_shape = paddle::framework::vectorize<int>(x->dims());
auto index_shape = paddle::framework::vectorize<int>(index->dims());
if (index_shape.size() == 1) {
index_shape.insert(index_shape.begin(), 1);
}
xpu::VectorParam<int> x_vec = {x_shape.data(),
static_cast<int>(x_shape.size()), nullptr};
auto &dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = XPU_SUCCESS;
......
......@@ -18,251 +18,140 @@ import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
import paddle
def gather_nd_grad(x, index):
dout_shape = index.shape[:-1] + x.shape[index.shape[-1]:]
numel = 1
for i in dout_shape:
numel = numel * i
dout = np.full(dout_shape, 1. / numel)
dx = np.full_like(x, 0)
index = tuple(index.reshape(-1, index.shape[-1]).T)
np.add.at(dx, index, dout)
return dx
def test_class1(op_type, typename):
class TestGatherNdOpWithEmptyIndex(XPUOpTest):
#Index has empty element, which means copy entire tensor
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.random((5, 20)).astype(typename)
self.inputs = {
'X': xnp,
'Index': np.array([[], []]).astype("int32")
}
self.outputs = {
'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :]))
}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_1".format(op_type, typename)
TestGatherNdOpWithEmptyIndex.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithEmptyIndex
def test_class2(op_type, typename):
class TestGatherNdOpWithIndex1(OpTest):
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.random((5, 20)).astype(typename)
self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_2".format(op_type, typename)
TestGatherNdOpWithIndex1.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithIndex1
def test_class3(op_type, typename):
class TestGatherNdOpWithLowIndex(OpTest):
#Index has low rank, X has high rank
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.uniform(0, 100, (10, 10)).astype(typename)
index = np.array([[1], [2]]).astype("int64")
self.inputs = {'X': xnp, 'Index': index}
self.outputs = {'Out': xnp[tuple(index.T)]}
self.x_grad = gather_nd_grad(xnp, index)
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
import paddle
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
cls_name = "{0}_{1}_3".format(op_type, typename)
TestGatherNdOpWithLowIndex.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithLowIndex
paddle.enable_static()
def test_class4(op_type, typename):
class TestGatherNdOpIndex1(OpTest):
#Index has low rank, X has high rank
class XPUTestGatherNd(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'gather_nd'
class XPUTestGatherNdBase(XPUOpTest):
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.uniform(0, 100, (10, 10)).astype(typename)
index = np.array([1, 2]).astype("int64")
self.inputs = {'X': xnp, 'Index': index}
self.outputs = {'Out': xnp[tuple(index.T)]}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_4".format(op_type, typename)
TestGatherNdOpIndex1.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpIndex1
def test_class5(op_type, typename):
class TestGatherNdOpWithSameIndexAsX(OpTest):
#Index has same rank as X's rank
def setUp(self):
self.set_xpu()
self.dtype = self.in_type
self.__class__.no_need_check_grad = True
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.uniform(0, 100, (10, 10)).astype(typename)
index = np.array([[1, 1], [2, 1]]).astype("int64")
self.inputs = {'X': xnp, 'Index': index}
self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22]
self.init_data()
def set_xpu(self):
self.__class__.use_xpu = True
self.inputs = {'X': self.xnp, 'Index': self.inp}
self.outputs = {'Out': self.output, }
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_5".format(op_type, typename)
TestGatherNdOpWithSameIndexAsX.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithSameIndexAsX
def test_class6(op_type, typename):
class TestGatherNdOpWithHighRankSame(OpTest):
#Both Index and X have high rank, and Rank(Index) = Rank(X)
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([[], []]).astype("int32")
self.output = np.vstack(
(self.xnp[np.newaxis, :], self.xnp[np.newaxis, :]))
class XPUTestGatherNdOpWithEmptyIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([[], []]).astype("int32")
self.output = np.vstack(
(self.xnp[np.newaxis, :], self.xnp[np.newaxis, :]))
class XPUTestGatherNdOpWithEmptyIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([[], []]).astype("int64")
self.output = np.vstack(
(self.xnp[np.newaxis, :], self.xnp[np.newaxis, :]))
class XPUTestGatherNdOpWithIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([1]).astype("int32")
self.output = self.xnp[self.inp]
class XPUTestGatherNdOpWithIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([1]).astype("int64")
self.output = self.xnp[self.inp]
class XPUTestGatherNdOpWithLowIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([[1], [2]]).astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpWithLowIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([1, 2]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpWithHighRankSame1(XPUTestGatherNdBase):
def init_data(self):
shape = (5, 2, 3, 1, 10)
xnp = np.random.rand(*shape).astype(typename)
index = np.vstack([np.random.randint(
0, s, size=2) for s in shape]).T
self.inputs = {'X': xnp, 'Index': index.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)]}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_6".format(op_type, typename)
TestGatherNdOpWithHighRankSame.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithHighRankSame
self.xnp = np.random.rand(*shape).astype(self.in_type)
self.inp = np.vstack(
[np.random.randint(
0, s, size=2) for s in shape]).T.astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
def test_class7(op_type, typename):
class TestGatherNdOpWithHighRankDiff(OpTest):
#Both Index and X have high rank, Rank(Index) < Rank(X)
class XPUTestGatherNdOpWithHighRankSame2(XPUTestGatherNdBase):
def init_data(self):
shape = (5, 2, 3, 1, 10)
self.xnp = np.random.rand(*shape).astype(self.in_type)
self.inp = np.vstack(
[np.random.randint(
0, s, size=2) for s in shape]).T.astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
class XPUTestGatherNdOpWithHighRankDiff1(XPUTestGatherNdBase):
def init_data(self):
shape = (2, 3, 4, 1, 10)
xnp = np.random.rand(*shape).astype(typename)
index = np.vstack(
self.xnp = np.random.rand(*shape).astype(self.in_type)
self.inp = np.vstack(
[np.random.randint(
0, s, size=200) for s in shape]).T
index_re = index.reshape([20, 5, 2, 5])
self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}
0, s, size=200) for s in shape]).T.astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_7".format(op_type, typename)
TestGatherNdOpWithHighRankDiff.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithHighRankDiff
class TestGatherNdAPI(unittest.TestCase):
def test_imperative(self):
paddle.disable_static()
input_1 = np.array([[1, 2], [3, 4], [5, 6]])
index_1 = np.array([[1]])
input = fluid.dygraph.to_variable(input_1)
index = fluid.dygraph.to_variable(index_1)
output = paddle.fluid.layers.gather(input, index)
output_np = output.numpy()
expected_output = np.array([3, 4])
self.assertTrue(np.allclose(output_np, expected_output))
paddle.enable_static()
for _typename in {'float32', 'int', 'int64'}:
test_class1('gather_nd', _typename)
test_class2('gather_nd', _typename)
test_class3('gather_nd', _typename)
test_class4('gather_nd', _typename)
test_class5('gather_nd', _typename)
test_class6('gather_nd', _typename)
test_class7('gather_nd', _typename)
class XPUTestGatherNdOpWithHighRankDiff2(XPUTestGatherNdBase):
def init_data(self):
shape = (2, 3, 4, 1, 10)
self.xnp = np.random.rand(*shape).astype(self.in_type)
self.inp = np.vstack(
[np.random.randint(
0, s, size=200) for s in shape]).T.astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpWithSameIndexAsX1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([[1, 1], [2, 1]]).astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpWithSameIndexAsX2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([[1, 1], [2, 1]]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([1, 2]).astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([1, 2]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
support_types = get_xpu_op_support_types('gather_nd')
for stype in support_types:
create_test_class(globals(), XPUTestGatherNd, stype)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册