diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 2e4b52c282d56769ffecd9bf382fb2c9bb0deea2..946027a22f88384a2bc968b8595ee1ed416a6439 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -19,6 +19,7 @@ import numpy as np from op_test import OpTest import paddle import paddle.fluid as fluid +from paddle.framework import core def gather_numpy(x, index, axis): @@ -298,5 +299,13 @@ class TestGathertError(unittest.TestCase): self.assertRaises(TypeError, test_index_type) +class TestCheckOutType(unittest.TestCase): + def test_out_type(self): + data = paddle.static.data(shape=[16, 10], dtype='int64', name='x') + index = paddle.static.data(shape=[4], dtype='int64', name='index') + out = paddle.gather(data, index) + self.assertTrue(out.dtype == core.VarDesc.VarType.INT64) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 060f9a1a9190410a872a4d028b623d770a19f738..bdda90315ac9c72c174356f01f8c9c99d2dcd447 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -804,7 +804,7 @@ def gather(x, index, axis=None, name=None): check_type(axis, 'axis', (int), 'gather') helper = LayerHelper('gather', **locals()) - dtype = helper.input_dtype() + dtype = helper.input_dtype('x') out = helper.create_variable_for_type_inference(dtype) helper.append_op( type="gather",