From 8040fa2bca72224c66ba6700dcc7e8ae79ea0554 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 17 Nov 2020 11:43:29 +0800 Subject: [PATCH] Fix output dtype inconsistent with input (#28649) * fix output dtyp inconsistent with input * refine code --- python/paddle/fluid/tests/unittests/test_gather_op.py | 9 +++++++++ python/paddle/tensor/manipulation.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 2e4b52c282d..946027a22f8 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -19,6 +19,7 @@ import numpy as np from op_test import OpTest import paddle import paddle.fluid as fluid +from paddle.framework import core def gather_numpy(x, index, axis): @@ -298,5 +299,13 @@ class TestGathertError(unittest.TestCase): self.assertRaises(TypeError, test_index_type) +class TestCheckOutType(unittest.TestCase): + def test_out_type(self): + data = paddle.static.data(shape=[16, 10], dtype='int64', name='x') + index = paddle.static.data(shape=[4], dtype='int64', name='index') + out = paddle.gather(data, index) + self.assertTrue(out.dtype == core.VarDesc.VarType.INT64) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 060f9a1a919..bdda90315ac 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -804,7 +804,7 @@ def gather(x, index, axis=None, name=None): check_type(axis, 'axis', (int), 'gather') helper = LayerHelper('gather', **locals()) - dtype = helper.input_dtype() + dtype = helper.input_dtype('x') out = helper.create_variable_for_type_inference(dtype) helper.append_op( type="gather", -- GitLab