未验证 提交 c145fd1e 编写于 作者: H houj04 提交者: GitHub

fix int8 support for full kernel (#52194)

* fix int8 support for full kernel

* fix ut.
上级 4118ab89
...@@ -116,6 +116,7 @@ PD_REGISTER_KERNEL(full, ...@@ -116,6 +116,7 @@ PD_REGISTER_KERNEL(full,
phi::FullKernel, phi::FullKernel,
float, float,
double, double,
int8_t,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
......
...@@ -75,12 +75,6 @@ class ApiOnesZerosError(unittest.TestCase): ...@@ -75,12 +75,6 @@ class ApiOnesZerosError(unittest.TestCase):
self.assertRaises(TypeError, test_error3) self.assertRaises(TypeError, test_error3)
def test_error4():
with paddle.static.program_guard(paddle.static.Program()):
ones = paddle.ones(shape=[10], dtype="int8")
self.assertRaises(TypeError, test_error4)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -21,14 +21,6 @@ from paddle import fluid ...@@ -21,14 +21,6 @@ from paddle import fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
class TestZerosOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
shape = [4]
dtype = 'int8'
self.assertRaises(TypeError, fluid.layers.zeros, shape, dtype)
class ApiZerosTest(unittest.TestCase): class ApiZerosTest(unittest.TestCase):
def test_out(self): def test_out(self):
with program_guard(Program()): with program_guard(Program()):
...@@ -46,11 +38,11 @@ class ApiZerosTest(unittest.TestCase): ...@@ -46,11 +38,11 @@ class ApiZerosTest(unittest.TestCase):
expected_result = np.zeros(10, dtype='int64') expected_result = np.zeros(10, dtype='int64')
self.assertEqual((result == expected_result).all(), True) self.assertEqual((result == expected_result).all(), True)
with program_guard(Program()): with program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype='int64') zeros = paddle.zeros(shape=[10], dtype='int8')
place = paddle.CPUPlace() place = paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
(result,) = exe.run(fetch_list=[zeros]) (result,) = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype='int64') expected_result = np.zeros(10, dtype='int8')
self.assertEqual((result == expected_result).all(), True) self.assertEqual((result == expected_result).all(), True)
with program_guard(Program()): with program_guard(Program()):
out_np = np.zeros(shape=1, dtype='float32') out_np = np.zeros(shape=1, dtype='float32')
...@@ -78,12 +70,6 @@ class ApiZerosError(unittest.TestCase): ...@@ -78,12 +70,6 @@ class ApiZerosError(unittest.TestCase):
self.assertRaises(TypeError, test_error1) self.assertRaises(TypeError, test_error1)
def test_error2():
with paddle.static.program_guard(fluid.Program()):
ones = fluid.layers.zeros(shape=[10], dtype='int8')
self.assertRaises(TypeError, test_error2)
def test_shape_errors(self): def test_shape_errors(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
try: try:
......
...@@ -881,7 +881,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -881,7 +881,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs = {'force_cpu': force_cpu} attrs = {'force_cpu': force_cpu}
dtype = convert_dtype(dtype) dtype = convert_dtype(dtype)
if not isinstance(value, Variable): if not isinstance(value, Variable):
if dtype in ['uint8', 'int16', 'int32', 'int64']: if dtype in ['int8', 'uint8', 'int16', 'int32', 'int64']:
attrs['str_value'] = str(int(value)) attrs['str_value'] = str(int(value))
attrs['value'] = int(value) attrs['value'] = int(value)
else: else:
...@@ -904,6 +904,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -904,6 +904,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
'float16', 'float16',
'float32', 'float32',
'float64', 'float64',
'int8',
'uint8', 'uint8',
'int16', 'int16',
'int32', 'int32',
...@@ -2045,6 +2046,7 @@ def assign(x, output=None): ...@@ -2045,6 +2046,7 @@ def assign(x, output=None):
'int32', 'int32',
'int64', 'int64',
'uint8', 'uint8',
'int8',
'bool', 'bool',
], ],
'assign', 'assign',
...@@ -2218,6 +2220,7 @@ def _memcpy(input, place=None, output=None): ...@@ -2218,6 +2220,7 @@ def _memcpy(input, place=None, output=None):
'int32', 'int32',
'int64', 'int64',
'uint8', 'uint8',
'int8',
'bool', 'bool',
], ],
'memcpy', 'memcpy',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册