未验证 提交 f86d0be7 编写于 作者: Z zengshao0622 提交者: GitHub

[AMP OP&Test] pad3d add unittests of fp16 and bf16 (#51015)

* pad3d add unittests of fp16 and bf16

* pad3d add unittests of fp16 and bf16

* fix cuda place

* fix random to uniform

* fix class name

* fix fp16 max relative error to 1.5e-3

* add dytpe register for onednn

* add pad uint16 check of common.py

* remove check_eager

* test_check_grad --> test_check_grad_normal
上级 7f86c1dc
......@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/pad3d_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/onednn/pad_kernel_impl.h"
......@@ -31,4 +32,10 @@ void Pad3dKernel(const Context& dev_ctx,
}
} // namespace phi
PD_REGISTER_KERNEL(pad3d, OneDNN, ONEDNN, phi::Pad3dKernel, float) {}
PD_REGISTER_KERNEL(pad3d,
OneDNN,
ONEDNN,
phi::Pad3dKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float) {}
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.nn.functional as F
......@@ -34,9 +34,14 @@ class TestPad3dOp(OpTest):
paddle.enable_static()
self.value = 0.0
self.initTestCase()
self.dtype = self.get_dtype()
self.op_type = "pad3d"
self.python_api = paddle.nn.functional.pad
self.inputs = {'X': np.random.random(self.shape).astype("float64")}
self.inputs = {
'X': np.random.uniform(-1.0, 1.0, self.shape).astype("float32")
if self.dtype == np.uint16
else np.random.uniform(-1.0, 1.0, self.shape).astype(self.dtype)
}
self.attrs = {}
if self.variable_paddings:
self.attrs['paddings'] = []
......@@ -81,12 +86,19 @@ class TestPad3dOp(OpTest):
out = np.pad(self.inputs['X'], paddings, mode="wrap")
self.outputs = {'Out': out}
if self.dtype == np.uint16:
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
def get_dtype(self):
return np.float64
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.paddings = [0, 0, 0, 0, 0, 0]
......@@ -190,6 +202,80 @@ class TestCase10(TestPad3dOp):
self.variable_paddings = True
# ----------------Pad3d Fp16----------------
def create_test_fp16(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestPad3dFp16(parent):
def get_dtype(self):
return np.float16
def test_check_output(self):
self.check_output(atol=1e-3)
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=1.5e-3)
cls_name = "{0}_{1}".format(parent.__name__, "FP16OP")
TestPad3dFp16.__name__ = cls_name
globals()[cls_name] = TestPad3dFp16
create_test_fp16(TestCase1)
create_test_fp16(TestCase2)
create_test_fp16(TestCase3)
create_test_fp16(TestCase4)
create_test_fp16(TestCase5)
create_test_fp16(TestCase6)
create_test_fp16(TestCase7)
create_test_fp16(TestCase8)
create_test_fp16(TestCase9)
create_test_fp16(TestCase10)
# ----------------Pad3d Bf16----------------
def create_test_bf16(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestPad3dBf16(parent):
def get_dtype(self):
return np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-2)
def test_check_grad_normal(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=1e-2
)
cls_name = "{0}_{1}".format(parent.__name__, "BF16OP")
TestPad3dBf16.__name__ = cls_name
globals()[cls_name] = TestPad3dBf16
create_test_bf16(TestCase1)
create_test_bf16(TestCase2)
create_test_bf16(TestCase3)
create_test_bf16(TestCase4)
create_test_bf16(TestCase5)
create_test_bf16(TestCase6)
create_test_bf16(TestCase7)
create_test_bf16(TestCase8)
create_test_bf16(TestCase9)
create_test_bf16(TestCase10)
class TestPadAPI(unittest.TestCase):
def setUp(self):
self.places = [paddle.CPUPlace()]
......
......@@ -1598,6 +1598,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):
'int64',
'complex64',
'complex128',
'uint16',
],
"pad",
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册