From 791963abe37a07a17084c8cb96833baa9deee54e Mon Sep 17 00:00:00 2001
From: Charles-hit <56987902+Charles-hit@users.noreply.github.com>
Date: Wed, 7 Jun 2023 15:06:29 +0800
Subject: [PATCH] support some prim ops bf16 dtype (#54399)

---
 paddle/phi/kernels/cpu/reduce_sum_kernel.cc |  1 +
 paddle/phi/kernels/reduce_sum_kernel.cc     |  1 +
 test/legacy_test/test_concat_op.py          | 12 +++++---
 test/legacy_test/test_elementwise_max_op.py | 21 ++++++++++---
 test/legacy_test/test_elementwise_pow_op.py | 11 +++++++
 test/legacy_test/test_gather_nd_op.py       | 33 +++++++++++++++++----
 test/legacy_test/test_reduce_op.py          |  5 +++-
 7 files changed, 70 insertions(+), 14 deletions(-)

diff --git a/paddle/phi/kernels/cpu/reduce_sum_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_kernel.cc
index 9cbfae1a637..aa690227bda 100644
--- a/paddle/phi/kernels/cpu/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_sum_kernel.cc
@@ -49,6 +49,7 @@ PD_REGISTER_KERNEL(sum_raw,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int16_t,
                    int,
                    int64_t,
diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc
index d249d31c3b6..de9688d4e60 100644
--- a/paddle/phi/kernels/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/reduce_sum_kernel.cc
@@ -44,6 +44,7 @@ PD_REGISTER_KERNEL(sum,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int16_t,
                    int,
                    int64_t,
diff --git a/test/legacy_test/test_concat_op.py b/test/legacy_test/test_concat_op.py
index 5ffe7a85655..1176ba32b20 100644
--- a/test/legacy_test/test_concat_op.py
+++ b/test/legacy_test/test_concat_op.py
@@ -136,8 +136,8 @@ class TestConcatOp6(TestConcatOp):
         self.dtype = self.get_dtype()
         self.python_api = paddle.concat
         self.public_python_api = paddle.concat
-        self.enable_cinn = False
         self.init_test_data()
+        self.if_enable_cinn()
         self.lod = [[20, 80]]
         self.out_lod = [[20, 80, 20, 80, 20, 80]]
         self.inputs = {
@@ -156,6 +156,9 @@ class TestConcatOp6(TestConcatOp):
         out = np.concatenate((self.x0, self.x1, self.x2), axis=self.actual_axis)
         self.outputs = {'Out': (out, self.out_lod)}
 
+    def if_enable_cinn(self):
+        pass
+
     def test_check_output(self):
         self.check_output()
 
@@ -177,7 +180,7 @@ class TestConcatOp7(TestConcatOp):
         self.python_api = paddle.concat
         self.public_python_api = paddle.concat
         self.prim_op_type = "prim"
-        self.enable_cinn = True
+        self.if_enable_cinn()
         self.dtype = self.get_dtype()
         self.init_test_data()
         self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
@@ -194,6 +197,9 @@ class TestConcatOp7(TestConcatOp):
             )
         }
 
+    def if_enable_cinn(self):
+        pass
+
     def get_dtype(self):
         return "float64"
 
@@ -226,7 +232,6 @@ def create_test_AxisTensor(parent):
             self.op_type = "concat"
             self.python_api = paddle.concat
             self.public_python_api = paddle.concat
-            self.enable_cinn = False
             self.dtype = self.get_dtype()
             self.init_test_data()
             self.inputs = {
@@ -286,7 +291,6 @@ def create_test_fp16(parent):
             self.op_type = "concat"
             self.python_api = paddle.concat
             self.public_python_api = paddle.concat
-            self.enable_cinn = False
             self.dtype = self.get_dtype()
             self.init_test_data()
             self.inputs = {
diff --git a/test/legacy_test/test_elementwise_max_op.py b/test/legacy_test/test_elementwise_max_op.py
index d3ff2c4f7c8..40aeaa50a0e 100644
--- a/test/legacy_test/test_elementwise_max_op.py
+++ b/test/legacy_test/test_elementwise_max_op.py
@@ -198,7 +198,6 @@ class TestElementwiseBF16Op(OpTest):
         self.python_api = paddle.maximum
         self.public_python_api = paddle.maximum
         self.prim_op_type = "prim"
-        self.enable_cinn = False
         self.dtype = np.uint16
         self.inputs = {
             'X': convert_float_to_uint16(self.x),
@@ -207,6 +206,7 @@
         self.outputs = {
             'Out': convert_float_to_uint16(np.maximum(self.x, self.y))
         }
+        self.if_enable_cinn()
 
     def test_check_output(self):
         if hasattr(self, 'attrs'):
@@ -214,6 +214,9 @@
         else:
             self.check_output(check_dygraph=True)
 
+    def if_enable_cinn(self):
+        pass
+
     def test_check_grad_normal(self):
         if hasattr(self, 'attrs'):
             # check_prim=False, bfloat16 is not supported in `less_equal`
@@ -221,16 +224,26 @@
                 ['X', 'Y'], 'Out', numeric_grad_delta=0.05, check_dygraph=False
             )
         else:
-            self.check_grad(['X', 'Y'], 'Out', numeric_grad_delta=0.05)
+            self.check_grad(
+                ['X', 'Y'], 'Out', numeric_grad_delta=0.05, check_prim=True
+            )
 
     def test_check_grad_ingore_x(self):
         self.check_grad(
-            ['Y'], 'Out', numeric_grad_delta=0.05, no_grad_set=set("X")
+            ['Y'],
+            'Out',
+            numeric_grad_delta=0.05,
+            no_grad_set=set("X"),
+            check_prim=True,
         )
 
     def test_check_grad_ingore_y(self):
         self.check_grad(
-            ['X'], 'Out', numeric_grad_delta=0.05, no_grad_set=set('Y')
+            ['X'],
+            'Out',
+            numeric_grad_delta=0.05,
+            no_grad_set=set('Y'),
+            check_prim=True,
         )
 
 
diff --git a/test/legacy_test/test_elementwise_pow_op.py b/test/legacy_test/test_elementwise_pow_op.py
index 9eba287b220..d450cc8a606 100644
--- a/test/legacy_test/test_elementwise_pow_op.py
+++ b/test/legacy_test/test_elementwise_pow_op.py
@@ -19,6 +19,7 @@ from eager_op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
 
 import paddle
 from paddle import fluid
+from paddle.fluid import core
 
 
 def pow_grad(x, y, dout):
@@ -270,8 +271,10 @@ class TestElementwisePowOpFP16(OpTest):
 class TestElementwisePowBF16Op(OpTest):
     def setUp(self):
         self.op_type = "elementwise_pow"
+        self.prim_op_type = "prim"
         self.dtype = np.uint16
         self.python_api = paddle.pow
+        self.public_python_api = paddle.pow
 
         x = np.random.uniform(0, 1, [20, 5]).astype(np.float32)
         y = np.random.uniform(0, 1, [20, 5]).astype(np.float32)
@@ -290,6 +293,14 @@ class TestElementwisePowBF16Op(OpTest):
 
     def test_check_grad(self):
         self.check_grad(['X', 'Y'], 'Out')
+        if core.is_compiled_with_cuda():
+            self.check_grad_with_place(
+                core.CUDAPlace(0),
+                ['X', 'Y'],
+                'Out',
+                check_prim=True,
+                only_check_prim=True,
+            )
 
 
 if __name__ == '__main__':
diff --git a/test/legacy_test/test_gather_nd_op.py b/test/legacy_test/test_gather_nd_op.py
index bc839b0c2db..530ea710226 100644
--- a/test/legacy_test/test_gather_nd_op.py
+++ b/test/legacy_test/test_gather_nd_op.py
@@ -15,7 +15,11 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest, convert_float_to_uint16
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)
 
 import paddle
 from paddle import fluid
@@ -31,6 +35,7 @@ class TestGatherNdOpWithEmptyIndex(OpTest):
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
         self.config_dtype()
+        self.if_enable_cinn()
         if self.dtype == np.float64:
             target_dtype = "float64"
         elif self.dtype == np.float16:
@@ -45,6 +50,9 @@ class TestGatherNdOpWithEmptyIndex(OpTest):
         self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")}
         self.outputs = {'Out': output}
 
+    def if_enable_cinn(self):
+        pass
+
     def config_dtype(self):
         self.dtype = np.float64
 
@@ -85,6 +93,7 @@ class TestGatherNdOpWithIndex1(OpTest):
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
         self.config_dtype()
+        self.if_enable_cinn()
         if self.dtype == np.float64:
             target_dtype = "float64"
         elif self.dtype == np.float16:
@@ -100,6 +109,9 @@ class TestGatherNdOpWithIndex1(OpTest):
         self.inputs = {'X': xnp, 'Index': index}
         self.outputs = {'Out': output}
 
+    def if_enable_cinn(self):
+        pass
+
     def config_dtype(self):
         self.dtype = np.float64
 
@@ -189,7 +201,9 @@ class TestGatherNdOpWithLowIndexBF16(TestGatherNdOpWithLowIndex):
 
     def test_check_grad(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_prim=True, numeric_grad_delta=0.5
+        )
 
 
 class TestGatherNdOpIndex1(OpTest):
@@ -208,6 +222,8 @@
         else:
             target_dtype = "float32"
         xnp = np.random.uniform(0, 100, (10, 10)).astype(target_dtype)
+        if self.dtype == np.uint16:
+            xnp = convert_uint16_to_float(convert_float_to_uint16(xnp))
         index = np.array([1, 2]).astype("int32")
         output = xnp[tuple(index.T)]
         if self.dtype == np.uint16:
@@ -215,6 +231,9 @@
             output = convert_float_to_uint16(output)
         self.inputs = {'X': xnp, 'Index': index}
         self.outputs = {'Out': output}
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
         # the outputs are 0D-tensor, CINN not support
         self.enable_cinn = False
 
@@ -225,7 +244,7 @@
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(['X'], 'Out', check_prim=True, numeric_grad_delta=0.05)
 
 
 class TestGatherNdOpIndex1FP16(TestGatherNdOpIndex1):
@@ -248,7 +267,9 @@ class TestGatherNdOpIndex1BF16(TestGatherNdOpIndex1):
 
     def test_check_grad(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_prim=True, numeric_grad_delta=0.5
+        )
 
 
 class TestGatherNdOpWithSameIndexAsX(OpTest):
@@ -304,7 +325,9 @@ class TestGatherNdOpWithSameIndexAsXBF16(TestGatherNdOpWithSameIndexAsX):
 
     def test_check_grad(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_prim=True, numeric_grad_delta=0.5
+        )
 
 
 class TestGatherNdOpWithHighRankSame(OpTest):
diff --git a/test/legacy_test/test_reduce_op.py b/test/legacy_test/test_reduce_op.py
index 0fabea0ad18..95d5fb5ceb2 100644
--- a/test/legacy_test/test_reduce_op.py
+++ b/test/legacy_test/test_reduce_op.py
@@ -277,13 +277,16 @@ class TestMaxOp_ZeroDim(OpTest):
         self.prim_op_type = "prim"
        self.python_api = paddle.max
         self.public_python_api = paddle.max
-        self.enable_cinn = False
+        self.if_enable_cinn()
         self.inputs = {'X': np.random.random([]).astype("float64")}
         self.attrs = {'dim': []}
         self.outputs = {
             'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
         }
 
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
     def test_check_output(self):
         self.check_output()
 
--
GitLab
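
Note on the bf16 pattern used in the tests above: the inputs are round-tripped through bfloat16 (convert_uint16_to_float(convert_float_to_uint16(xnp))) so that the NumPy reference is computed at the same precision the kernel sees, and numeric_grad_delta is loosened because bf16 carries only about 8 bits of mantissa. Below is a minimal NumPy sketch of that round trip, assuming bfloat16 is the upper 16 bits of an IEEE float32 with round-to-nearest-even; the function names are hypothetical stand-ins, not the actual eager_op_test helpers.

import numpy as np


def float32_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32, rounding to nearest even
    # (assumed behaviour; stand-in for convert_float_to_uint16).
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    rounding_bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    return ((bits + rounding_bias) >> 16).astype(np.uint16)


def bf16_bits_to_float32(b):
    # Re-expand stored bfloat16 bit patterns into float32 values
    # (stand-in for convert_uint16_to_float).
    return (b.astype(np.uint32) << 16).view(np.float32)


# Quantize the input before building the reference output, mirroring
# convert_uint16_to_float(convert_float_to_uint16(xnp)) in the tests.
xnp = np.random.uniform(0, 100, (10, 10)).astype(np.float32)
xnp_bf16 = bf16_bits_to_float32(float32_to_bf16_bits(xnp))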