未验证 提交 c145fd1e 编写于 作者: H houj04 提交者: GitHub

fix int8 support for full kernel (#52194)

* fix int8 support for full kernel

* fix ut.
上级 4118ab89
......@@ -116,6 +116,7 @@ PD_REGISTER_KERNEL(full,
phi::FullKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
......
......@@ -75,12 +75,6 @@ class ApiOnesZerosError(unittest.TestCase):
self.assertRaises(TypeError, test_error3)
def test_error4():
with paddle.static.program_guard(paddle.static.Program()):
ones = paddle.ones(shape=[10], dtype="int8")
self.assertRaises(TypeError, test_error4)
if __name__ == "__main__":
unittest.main()
......@@ -21,14 +21,6 @@ from paddle import fluid
from paddle.fluid import Program, program_guard
class TestZerosOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
shape = [4]
dtype = 'int8'
self.assertRaises(TypeError, fluid.layers.zeros, shape, dtype)
class ApiZerosTest(unittest.TestCase):
def test_out(self):
with program_guard(Program()):
......@@ -46,11 +38,11 @@ class ApiZerosTest(unittest.TestCase):
expected_result = np.zeros(10, dtype='int64')
self.assertEqual((result == expected_result).all(), True)
with program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype='int64')
zeros = paddle.zeros(shape=[10], dtype='int8')
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
(result,) = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype='int64')
expected_result = np.zeros(10, dtype='int8')
self.assertEqual((result == expected_result).all(), True)
with program_guard(Program()):
out_np = np.zeros(shape=1, dtype='float32')
......@@ -78,12 +70,6 @@ class ApiZerosError(unittest.TestCase):
self.assertRaises(TypeError, test_error1)
def test_error2():
with paddle.static.program_guard(fluid.Program()):
ones = fluid.layers.zeros(shape=[10], dtype='int8')
self.assertRaises(TypeError, test_error2)
def test_shape_errors(self):
with fluid.dygraph.guard():
try:
......
......@@ -881,7 +881,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs = {'force_cpu': force_cpu}
dtype = convert_dtype(dtype)
if not isinstance(value, Variable):
if dtype in ['uint8', 'int16', 'int32', 'int64']:
if dtype in ['int8', 'uint8', 'int16', 'int32', 'int64']:
attrs['str_value'] = str(int(value))
attrs['value'] = int(value)
else:
......@@ -904,6 +904,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
'float16',
'float32',
'float64',
'int8',
'uint8',
'int16',
'int32',
......@@ -2045,6 +2046,7 @@ def assign(x, output=None):
'int32',
'int64',
'uint8',
'int8',
'bool',
],
'assign',
......@@ -2218,6 +2220,7 @@ def _memcpy(input, place=None, output=None):
'int32',
'int64',
'uint8',
'int8',
'bool',
],
'memcpy',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册