# test_gather_op.py: unit tests for the "gather" operator.
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import numpy
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator


class TestGatherOp(unittest.TestCase):
    """Forward-pass fixture for the ``gather`` operator.

    ``setUp`` only declares the operator type, its inputs and the expected
    output; the test methods themselves are presumably generated by
    ``OpTestMeta`` from these attributes (confirm in ``op_test_util``).
    """

    # Python 2 style metaclass hook, kept as in the original file.
    __metaclass__ = OpTestMeta

    def setUp(self):
        """Build a random 10x20 float32 input and the expected gathered rows."""
        self.type = "gather"
        xnp = numpy.random.random((10, 20)).astype("float32")
        index = numpy.array([1, 3, 5]).astype("int32")
        self.inputs = {
            'X': xnp,
            'Index': index,
        }
        # Reference result: NumPy integer-array indexing picks rows 1, 3, 5.
        self.outputs = {'Out': xnp[index]}


class TestGatherGradOp(GradientChecker):
    """Gradient check for the ``gather`` operator."""

    def test_gather_grad(self):
        """Run ``check_grad`` on a ``gather`` op with a random 10x20 input."""
        op = create_op("gather")
        xnp = numpy.random.random((10, 20)).astype("float32")
        index = numpy.array([1, 3, 5]).astype("int32")
        inputs = {'X': xnp, 'Index': index}
        # Use an explicit one-element set: ``set("X")`` iterates the string's
        # characters and would silently split any multi-character name.
        # NOTE(review): presumably the third argument names the inputs whose
        # gradients are checked and the fourth the output; confirm against
        # ``GradientChecker.check_grad``.
        self.check_grad(op, inputs, {"X"}, "Out")


# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()