未验证 提交 1550348e 编写于 作者: L LoneRanger 提交者: GitHub

【PaddlePaddle Hackathon 4】No.63 fix temporal_shift and conj (#51532)

* add fp16 and bfp16 for temporalshift

* add fp16 and bfp16 for complex

* fix bug

* fix bug

* add fp16 and bf16 for conj

* fix bug

* fix bug

* Update complex_kernel.h

fix bug

* Update temporal_shift_grad_kernel.h

fix bug

* Update temporal_shift_kernel.h

fix bug
上级 a82911a5
...@@ -26,6 +26,7 @@ PD_REGISTER_KERNEL(conj, ...@@ -26,6 +26,7 @@ PD_REGISTER_KERNEL(conj,
ALL_LAYOUT, ALL_LAYOUT,
phi::ConjKernel, phi::ConjKernel,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>, phi::dtype::complex<double>,
float, float,
......
...@@ -146,4 +146,5 @@ PD_REGISTER_KERNEL(temporal_shift_grad, ...@@ -146,4 +146,5 @@ PD_REGISTER_KERNEL(temporal_shift_grad,
phi::TemporalShiftGradKernel, phi::TemporalShiftGradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -146,4 +146,5 @@ PD_REGISTER_KERNEL(temporal_shift, ...@@ -146,4 +146,5 @@ PD_REGISTER_KERNEL(temporal_shift,
phi::TemporalShiftKernel, phi::TemporalShiftKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -20,9 +20,10 @@ import numpy as np ...@@ -20,9 +20,10 @@ import numpy as np
import paddle import paddle
sys.path.append("..") sys.path.append("..")
from eager_op_test import OpTest from eager_op_test import OpTest, convert_float_to_uint16
from numpy.random import random as rand from numpy.random import random as rand
import paddle.fluid.core as core
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
import paddle.static as static import paddle.static as static
...@@ -147,5 +148,43 @@ class Testfp16ConjOp(unittest.TestCase): ...@@ -147,5 +148,43 @@ class Testfp16ConjOp(unittest.TestCase):
out = exe.run(feed={'x': input_x}, fetch_list=[out]) out = exe.run(feed={'x': input_x}, fetch_list=[out])
class TestConjFP16OP(TestConjOp):
def init_dtype_type(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestConjBF16(OpTest):
def setUp(self):
self.op_type = "conj"
self.python_api = paddle.tensor.conj
self.init_dtype_type()
self.init_input_output()
def init_dtype_type(self):
self.dtype = np.uint16
def init_input_output(self):
x = (
np.random.random((12, 14)) + 1j * np.random.random((12, 14))
).astype(np.float32)
out = np.conj(x)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': convert_float_to_uint16(out)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest, convert_float_to_uint16
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
...@@ -44,6 +44,7 @@ def temporal_shift(x, seg_num, shift_ratio, data_format): ...@@ -44,6 +44,7 @@ def temporal_shift(x, seg_num, shift_ratio, data_format):
class TestTemporalShift(OpTest): class TestTemporalShift(OpTest):
def setUp(self): def setUp(self):
self.initTestCase() self.initTestCase()
self.init_dtype()
self.op_type = 'temporal_shift' self.op_type = 'temporal_shift'
self.python_api = paddle.nn.functional.temporal_shift self.python_api = paddle.nn.functional.temporal_shift
x = np.random.random(self.x_shape).astype(self.dtype) x = np.random.random(self.x_shape).astype(self.dtype)
...@@ -64,6 +65,9 @@ class TestTemporalShift(OpTest): ...@@ -64,6 +65,9 @@ class TestTemporalShift(OpTest):
self.outputs = {"Out": output} self.outputs = {"Out": output}
self.python_out_sig = ["Out"] self.python_out_sig = ["Out"]
def init_dtype(self):
self.dtype = 'float64'
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True)
...@@ -74,7 +78,6 @@ class TestTemporalShift(OpTest): ...@@ -74,7 +78,6 @@ class TestTemporalShift(OpTest):
self.x_shape = (6, 4, 4, 4) self.x_shape = (6, 4, 4, 4)
self.seg_num = 3 self.seg_num = 3
self.shift_ratio = 0.25 self.shift_ratio = 0.25
self.dtype = 'float64'
self.data_format = 'NCHW' self.data_format = 'NCHW'
...@@ -174,6 +177,56 @@ class TestTemporalShiftAPI(unittest.TestCase): ...@@ -174,6 +177,56 @@ class TestTemporalShiftAPI(unittest.TestCase):
self.assertRaises(ValueError, attr_data_format) self.assertRaises(ValueError, attr_data_format)
class TestTemporalShiftFP16OP(TestTemporalShift):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestTemporalShiftBF16(OpTest):
def initTestCase(self):
self.x_shape = (3, 10, 5, 5)
self.seg_num = 1
self.shift_ratio = 0.3
self.dtype = np.uint16
self.data_format = 'NCHW'
def setUp(self):
self.initTestCase()
self.op_type = 'temporal_shift'
self.python_api = paddle.nn.functional.temporal_shift
x = np.random.random(self.x_shape).astype(np.float32)
self.attrs = {
"seg_num": self.seg_num,
"shift_ratio": self.shift_ratio,
"data_format": self.data_format,
}
self.inputs = {
"X": convert_float_to_uint16(x),
}
output = temporal_shift(
x, self.seg_num, self.shift_ratio, self.data_format
)
self.outputs = {"Out": convert_float_to_uint16(output)}
self.python_out_sig = ["Out"]
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad_ignore_uv(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册