Unverified · Commit 3be6791f authored by YuhangLi, committed by GitHub

[AMP OP&Test] Mean fp/bf 16 support (#51114)

* mean fp16

* fp16

* [AMP OP&Test] mean append bf/fp16

* means append more bf16 uts

* format class name

* fix ci

* fix for windows

* fix issue

* fix redundancy

* fix redund

* fix elewise_max ut bf16 numeric delta

* remove use func
Parent 973dab86
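In short, this commit registers bf16 kernels (alongside the existing fp16 ones) for the mean, mean_grad, and mean_raw ops, and adds the matching unit tests. A minimal sketch of what the new registrations enable from the Python side (hypothetical usage, assuming a CUDA build of Paddle on a device that supports bf16):

    import paddle

    x = paddle.rand([2, 3, 4, 5])
    # fp16 was already registered; this PR adds the bf16 variants.
    y_fp16 = paddle.mean(x.astype('float16'), axis=[0])
    y_bf16 = paddle.mean(x.astype('bfloat16'), axis=[0])
    print(y_fp16.dtype, y_bf16.dtype)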
@@ -66,4 +66,5 @@ PD_REGISTER_KERNEL(mean_grad,
                    bool,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -45,6 +45,7 @@ PD_REGISTER_KERNEL(mean_raw,
                    float,
                    double,
                    bool,
+                   phi::dtype::bfloat16,
                    float16,
                    int,
                    int64_t) {}
@@ -44,7 +44,8 @@ PD_REGISTER_KERNEL(mean,
                    bool,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #endif

 #if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
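All three registration hunks above do the same thing: append phi::dtype::bfloat16 to the kernel's dtype list so the dispatcher accepts bf16 tensors. A guarded usage sketch that mirrors the skip condition used by the tests below (assumes a CUDA build; core.is_bfloat16_supported is the same helper the tests call):

    import paddle
    from paddle.fluid import core

    place = paddle.CUDAPlace(0)
    if core.is_bfloat16_supported(place):
        # Dispatches to the newly registered bf16 mean kernel.
        x = paddle.rand([4, 5]).astype('bfloat16')
        print(paddle.mean(x, axis=0))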
@@ -17,7 +17,7 @@ import unittest
 import gradient_checker
 import numpy as np
 from decorator_helper import prog_scope
-from op_test import OpTest, OpTestTool
+from op_test import OpTest, OpTestTool, convert_float_to_uint16
 from test_sum_op import TestReduceOPTensorAxisBase

 import paddle
@@ -141,14 +141,11 @@ def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
     return np.mean(x, axis=axis, keepdims=keepdim)


-def ref_reduce_mean_grad(x, axis, dtype, reduce_all):
-    if reduce_all:
-        axis = list(range(x.ndim))
-
-    shape = [x.shape[i] for i in axis]
-    return (1.0 / np.prod(shape) * np.ones(shape)).astype(dtype)
-
-
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_float16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA",
+)
 class TestReduceMeanOp(OpTest):
     def setUp(self):
         self.op_type = 'reduce_mean'
@@ -173,9 +170,6 @@ class TestReduceMeanOp(OpTest):
             'reduce_all': self.reduce_all,
         }

-        if self.dtype == 'float16':
-            self.__class__.no_need_check_grad = True
-
     def set_attrs(self):
         pass

@@ -183,8 +177,6 @@ class TestReduceMeanOp(OpTest):
         if self.dtype != 'float16':
             self.check_output(check_eager=True)
         else:
-            if not core.is_compiled_with_cuda():
-                return
             place = paddle.CUDAPlace(0)
             self.check_output_with_place(place=place)

@@ -192,24 +184,53 @@ class TestReduceMeanOp(OpTest):
         if self.dtype != 'float16':
             self.check_grad(['X'], ['Out'], check_eager=True)
         else:
-            if not core.is_compiled_with_cuda():
-                return
             place = paddle.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                return
-            with fluid.dygraph.guard(place=place):
-                x = paddle.tensor(self.inputs['X'])
-                y = paddle.mean(
-                    x, axis=self.attrs['dim'], keepdim=self.attrs['keep_dim']
-                )
-                dx = paddle.grad(y, x)[0].numpy()
-                dx_expected = ref_reduce_mean_grad(
-                    self.inputs['X'],
-                    self.attrs['dim'],
-                    self.dtype,
-                    self.attrs['reduce_all'],
-                )
-                np.testing.assert_array_equal(dx, dx_expected)
+            self.check_grad_with_place(
+                place, ['X'], ['Out'], numeric_grad_delta=0.5
+            )
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and do not support bfloat16",
+)
+class TestReduceMeanBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = 'reduce_mean'
+        self.python_api = reduce_mean_wrapper
+        self.dtype = np.uint16
+        self.shape = [2, 3, 4, 5]
+        self.axis = [0]
+        self.keepdim = False
+        self.set_attrs()
+        np.random.seed(10)
+
+        x_np = np.random.uniform(-1, 1, self.shape).astype(np.float32)
+        if not hasattr(self, "reduce_all"):
+            self.reduce_all = (not self.axis) or len(self.axis) == len(x_np)
+        out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all)
+        self.inputs = {'X': convert_float_to_uint16(x_np)}
+        self.outputs = {'Out': convert_float_to_uint16(out_np)}
+        self.attrs = {
+            'dim': self.axis,
+            'keep_dim': self.keepdim,
+            'reduce_all': self.reduce_all,
+        }
+
+    def set_attrs(self):
+        pass
+
+    def test_check_output(self):
+        place = paddle.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = paddle.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], ['Out'], numeric_grad_delta=0.05
+        )


 class TestReduceMeanOpDefaultAttrs(TestReduceMeanOp):
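The new TestReduceMeanBF16Op carries its data as np.uint16 because numpy has no native bfloat16 dtype: convert_float_to_uint16 keeps the upper 16 bits of each fp32 value, which is exactly the bf16 bit pattern. A simplified, truncating sketch of that conversion (the real op_test helper may round rather than truncate; the function name here is made up for illustration):

    import numpy as np

    def float_to_bf16_bits(x):
        # bf16 = top half-word of the IEEE-754 fp32 bit pattern.
        return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

    x_np = np.random.uniform(-1, 1, [2, 3]).astype(np.float32)
    print(float_to_bf16_bits(x_np))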
@@ -251,6 +272,11 @@ class TestReduceMeanOpShape6D(TestReduceMeanOp):
         self.shape = [2, 3, 4, 5, 6, 7]


+class TestReduceMeanOpShape6DBF16(TestReduceMeanBF16Op):
+    def set_attrs(self):
+        self.shape = [2, 3, 4, 5, 6, 7]
+
+
 class TestReduceMeanOpShape6DFP16(TestReduceMeanOp):
     def set_attrs(self):
         self.shape = [2, 3, 4, 5, 6, 7]
@@ -268,6 +294,11 @@ class TestReduceMeanOpAxisAllFP16(TestReduceMeanOp):
         self.dtype = 'float16'


+class TestReduceMeanOpAxisAllBF16(TestReduceMeanBF16Op):
+    def set_attrs(self):
+        self.axis = [0, 1, 2, 3]
+
+
 class TestReduceMeanOpAxisTuple(TestReduceMeanOp):
     def set_attrs(self):
         self.axis = (0, 1, 2)
@@ -279,6 +310,11 @@ class TestReduceMeanOpAxisTupleFP16(TestReduceMeanOp):
         self.dtype = 'float16'


+class TestReduceMeanOpAxisTupleBF16(TestReduceMeanBF16Op):
+    def set_attrs(self):
+        self.axis = (0, 1, 2)
+
+
 class TestReduceMeanOpAxisNegative(TestReduceMeanOp):
     def set_attrs(self):
         self.axis = [-2, -1]
@@ -290,6 +326,11 @@ class TestReduceMeanOpAxisNegativeFP16(TestReduceMeanOp):
         self.dtype = 'float16'


+class TestReduceMeanOpAxisNegativeBF16(TestReduceMeanBF16Op):
+    def set_attrs(self):
+        self.axis = [-2, -1]
+
+
 class TestReduceMeanOpKeepdimTrue1(TestReduceMeanOp):
     def set_attrs(self):
         self.keepdim = True
@@ -301,6 +342,11 @@ class TestReduceMeanOpKeepdimTrue1FP16(TestReduceMeanOp):
         self.dtype = 'float16'


+class TestReduceMeanOpKeepdimTrue1BF16(TestReduceMeanBF16Op):
+    def set_attrs(self):
+        self.keepdim = True
+
+
 class TestReduceMeanOpKeepdimTrue2(TestReduceMeanOp):
     def set_attrs(self):
         self.axis = [0, 1, 2, 3]
@@ -314,6 +360,12 @@ class TestReduceMeanOpKeepdimTrue2FP16(TestReduceMeanOp):
         self.dtype = 'float16'


+class TestReduceMeanOpKeepdimTrue2BF16(TestReduceMeanBF16Op):
+    def set_attrs(self):
+        self.axis = [0, 1, 2, 3]
+        self.keepdim = True
+
+
 class TestReduceMeanOpReduceAllTrue(TestReduceMeanOp):
     def set_attrs(self):
         self.reduce_all = True
@@ -325,6 +377,11 @@ class TestReduceMeanOpReduceAllTrueFP16(TestReduceMeanOp):
         self.dtype = 'float16'


+class TestReduceMeanOpReduceAllTrueBF16(TestReduceMeanBF16Op):
+    def set_attrs(self):
+        self.reduce_all = True
+
+
 class TestMeanAPI(unittest.TestCase):
     # test paddle.tensor.stat.mean
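A note on the widened numeric_grad_delta values (0.5 for the fp16 grad check, 0.05 for bf16): check_grad_with_place compares the analytic gradient against a finite-difference estimate, and low-precision dtypes need looser tolerances. An illustrative central-difference check for the mean reduction in plain numpy (names here are invented for the sketch, not OpTest internals):

    import numpy as np

    def numeric_grad(f, x, eps=1e-3):
        # d(sum of outputs)/dx via central differences, one element at a time.
        g = np.zeros_like(x)
        for i in np.ndindex(*x.shape):
            x_hi, x_lo = x.copy(), x.copy()
            x_hi[i] += eps
            x_lo[i] -= eps
            g[i] = (f(x_hi).sum() - f(x_lo).sum()) / (2 * eps)
        return g

    x = np.random.uniform(-1, 1, [2, 3]).astype(np.float32)
    analytic = np.full_like(x, 1.0 / x.shape[0])  # gradient of y = mean(x, axis=0)
    numeric = numeric_grad(lambda t: t.mean(axis=0), x)
    np.testing.assert_allclose(numeric, analytic, atol=5e-2)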