未验证 提交 8040fa2b 编写于 作者: A Aurelius84 提交者: GitHub

Fix output dtype inconsistent with input (#28649)

* fix output dtyp inconsistent with input

* refine code
上级 57dab959
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.framework import core
def gather_numpy(x, index, axis):
......@@ -298,5 +299,13 @@ class TestGathertError(unittest.TestCase):
self.assertRaises(TypeError, test_index_type)
class TestCheckOutType(unittest.TestCase):
def test_out_type(self):
data = paddle.static.data(shape=[16, 10], dtype='int64', name='x')
index = paddle.static.data(shape=[4], dtype='int64', name='index')
out = paddle.gather(data, index)
self.assertTrue(out.dtype == core.VarDesc.VarType.INT64)
if __name__ == "__main__":
unittest.main()
......@@ -804,7 +804,7 @@ def gather(x, index, axis=None, name=None):
check_type(axis, 'axis', (int), 'gather')
helper = LayerHelper('gather', **locals())
dtype = helper.input_dtype()
dtype = helper.input_dtype('x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="gather",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册