Unverified commit ff062a43, authored by Guoxia Wang, committed by GitHub

fix output dtype for paddle.sum (#34313)

* support bool dtype for paddle.sum
Parent a842828a
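For context, a minimal sketch of the behavior this commit establishes, assuming a Paddle build that includes it: `paddle.sum` now accepts `bool` inputs, and `bool`/`int32`/`int64` inputs are accumulated as `int64` by default, so small integer types cannot silently overflow.

```python
import numpy as np
import paddle

# bool input: previously rejected by sum's dtype check, now summed
# with an int64 result.
x = paddle.to_tensor(np.array([True, True, False]))
print(paddle.sum(x).dtype)  # paddle.int64

# int32 input: promoted to int64 by default, so this sum does not
# wrap around even though 4 * 2**30 exceeds the int32 range.
y = paddle.to_tensor(np.full([4], 2**30, dtype='int32'))
print(paddle.sum(y).dtype)  # paddle.int64
print(int(paddle.sum(y)))   # 4294967296
```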
@@ -184,7 +184,7 @@ def test_optim_break_in_while(x):
 class TestContinueInFor(unittest.TestCase):
     def setUp(self):
-        self.input = np.zeros((1)).astype('int32')
+        self.input = np.zeros((1)).astype('int64')
         self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
         self.init_dygraph_func()
......
@@ -748,37 +748,6 @@ class TestReduceSumOpError(unittest.TestCase):
             self.assertRaises(TypeError, fluid.layers.reduce_sum, x2)
 
 
-class API_TestSumOpError(unittest.TestCase):
-    def test_errors(self):
-        def test_dtype1():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="float64")
-                paddle.sum(data, dtype="float32")
-
-        self.assertRaises(ValueError, test_dtype1)
-
-        def test_dtype2():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="int64")
-                paddle.sum(data, dtype="int32")
-
-        self.assertRaises(ValueError, test_dtype2)
-
-        def test_dtype3():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="float64")
-                paddle.sum(data, dtype="int32")
-
-        self.assertRaises(ValueError, test_dtype3)
-
-        def test_type():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="int32")
-                paddle.sum(data, dtype="bool")
-
-        self.assertRaises(TypeError, test_type)
-
-
 class API_TestSumOp(unittest.TestCase):
     def run_static(self,
                    shape,
@@ -805,14 +774,26 @@ class API_TestSumOp(unittest.TestCase):
         shape = [10, 10]
         axis = 1
 
+        self.run_static(shape, "bool", axis, attr_dtype=None)
+        self.run_static(shape, "bool", axis, attr_dtype="int32")
+        self.run_static(shape, "bool", axis, attr_dtype="int64")
+
         self.run_static(shape, "int32", axis, attr_dtype=None)
         self.run_static(shape, "int32", axis, attr_dtype="int32")
         self.run_static(shape, "int32", axis, attr_dtype="int64")
 
+        self.run_static(shape, "int64", axis, attr_dtype=None)
+        self.run_static(shape, "int64", axis, attr_dtype="int64")
+        self.run_static(shape, "int64", axis, attr_dtype="int32")
+
         self.run_static(shape, "float32", axis, attr_dtype=None)
         self.run_static(shape, "float32", axis, attr_dtype="float32")
         self.run_static(shape, "float32", axis, attr_dtype="float64")
 
+        self.run_static(shape, "float64", axis, attr_dtype=None)
+        self.run_static(shape, "float64", axis, attr_dtype="float32")
+        self.run_static(shape, "float64", axis, attr_dtype="float64")
+
         shape = [5, 5, 5]
         self.run_static(shape, "int32", (0, 1), attr_dtype="int32")
         self.run_static(
......
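The dtype/attr_dtype combinations above encode the promotion rule implemented by the new `get_dtype` helper in the `math.py` hunk below. A hedged restatement of that rule in plain Python (`expected_sum_dtype` is an illustrative name, not a Paddle API):

```python
def expected_sum_dtype(src_dtype, requested_dtype=None):
    """Illustrative helper mirroring get_dtype in paddle.tensor.math.sum."""
    if requested_dtype is not None:
        # An explicit dtype always wins, including narrowing casts such as
        # int64 -> int32, which no longer raise ValueError after this commit.
        return requested_dtype
    if src_dtype in ('bool', 'int32', 'int64'):
        return 'int64'  # integer-like inputs accumulate in int64
    return src_dtype    # float32/float64 keep their own dtype

assert expected_sum_dtype('bool') == 'int64'
assert expected_sum_dtype('int32') == 'int64'
assert expected_sum_dtype('float32') == 'float32'
assert expected_sum_dtype('int64', 'int32') == 'int32'
```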
@@ -716,13 +716,15 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     else:
         reduce_all_flag = False
 
-    dtype_flag = False
-    if dtype is not None:
-        if dtype in ['float64', 'int64']:
-            if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
-                    (convert_dtype(x.dtype) == "int32" and dtype == "int64"):
-                dtype_flag = True
+    def get_dtype(x, dtype):
+        if dtype is not None:
+            return (True, dtype)
+        src_type = convert_dtype(x.dtype)
+        if src_type in ['bool','int32', 'int64']:
+            return (True, 'int64')
+        return (False, src_type)
+    dtype_flag, dtype = get_dtype(x, dtype)
 
     if in_dygraph_mode():
         axis = axis if axis != None and axis != [] else [0]
         if dtype_flag:
@@ -740,27 +742,17 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
         'reduce_all': reduce_all_flag
     }
 
-    if dtype is not None:
-        if dtype in ['float64', 'int64']:
-            if (convert_dtype(x.dtype) == "float32" and dtype == "float64") or \
-                    (convert_dtype(x.dtype) == "int32" and dtype == "int64"):
-                attrs.update({
-                    'in_dtype': x.dtype,
-                    'out_dtype': convert_np_dtype_to_dtype_(dtype)
-                })
+    if dtype_flag:
+        attrs.update({
+            'in_dtype': x.dtype,
+            'out_dtype': convert_np_dtype_to_dtype_(dtype)
+        })
 
     check_variable_and_dtype(
-        x, 'x', ['float32', 'float64', 'int32', 'int64'], 'sum')
-
-    if dtype is not None:
-        check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'sum')
-        x_dtype = convert_dtype(x.dtype)
-
-        if (x_dtype == "float64" and dtype in ["float32", "int32"]) or \
-                (x_dtype == "int64" and dtype == "int32"):
-            raise ValueError("The input(x)'s dtype is {} but the attr(dtype) of sum is {}, "
-                             "which may cause data type overflows. Please reset attr(dtype) of sum."
-                             .format(x_dtype, dtype))
+        x, 'x', ['bool', 'float16', 'float32', 'float64',
+                 'int32', 'int64', 'complex64', 'complex128',
+                 u'bool', u'float16', u'float32', u'float64',
+                 u'int32', u'int64', u'complex64', u'complex128'], 'sum')
 
     check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum')
......
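Note that the deleted `API_TestSumOpError` cases and the deleted `ValueError` branch mean an explicitly requested narrowing `dtype` is now honored rather than rejected. A minimal sketch under that assumption:

```python
import paddle

x = paddle.ones([10], dtype='float64')
# Before this commit: ValueError ("may cause data type overflows").
# After: the explicit dtype is returned by get_dtype and applied.
out = paddle.sum(x, dtype='float32')
print(out.dtype)   # paddle.float32
print(float(out))  # 10.0
```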