Unverified commit f06dd08d, authored by Yuang Liu, committed by GitHub

【AMP OP&Test】Support bf16 scatter and scatter_nd_add, add bf16/fp16 ut. (#51689)

Parent 80472116
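Before the diff itself, a minimal usage sketch of what this commit enables: calling paddle.scatter and paddle.scatter_nd_add on low-precision tensors on GPU. This is illustrative only and not part of the change; it assumes a CUDA build of Paddle, and the bfloat16 cast is only expected to work where the device and runtime support bf16.

# Illustrative sketch (not part of the commit): exercise the newly registered
# fp16/bf16 GPU kernels for scatter and scatter_nd_add.
import numpy as np
import paddle

paddle.set_device('gpu')  # assumes a CUDA build

x = paddle.to_tensor(np.ones((3, 50), dtype=np.float32)).astype('float16')
index = paddle.to_tensor(np.array([1, 2], dtype=np.int64))
updates = paddle.to_tensor(np.random.rand(2, 50).astype(np.float32)).astype('float16')

# fp16 scatter: rows 1 and 2 of x are replaced by `updates`.
out = paddle.scatter(x, index, updates, overwrite=True)

# fp16 scatter_nd_add: `updates` rows are accumulated at the nd indices.
nd_index = paddle.to_tensor(np.array([[0], [2]], dtype=np.int64))
out_nd = paddle.scatter_nd_add(x, nd_index, updates)

# The bf16 variants would look the same with .astype('bfloat16'),
# assuming the runtime supports casting to bfloat16.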
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/scatter_grad_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/gather.cu.h"
@@ -72,4 +73,5 @@ PD_REGISTER_KERNEL(scatter_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/scatter_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
@@ -60,4 +61,5 @@ PD_REGISTER_KERNEL(scatter,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/scatter_nd_add_grad_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/gather.cu.h"
@@ -53,4 +54,5 @@ PD_REGISTER_KERNEL(scatter_nd_add_grad,
                    double,
                    int64_t,
                    int,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/scatter_nd_add_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
@@ -56,4 +57,5 @@ PD_REGISTER_KERNEL(scatter_nd_add,
                    double,
                    int64_t,
                    int,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -15,10 +15,11 @@
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 from paddle.fluid.dygraph.base import switch_to_static_graph
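The tests represent bfloat16 data as numpy uint16 arrays (numpy has no native bf16 dtype), which is why the bf16 cases below set self.dtype = np.uint16 and pass inputs through convert_float_to_uint16. A rough, hypothetical sketch of what such a conversion does; the real helper lives in op_test and may round to nearest rather than truncate:

import numpy as np

def bf16_bits_from_float(arr):
    # Hypothetical stand-in for op_test.convert_float_to_uint16: keep the
    # upper 16 bits of each float32 value (sign, exponent, 7 mantissa bits)
    # and store them in a uint16 array.
    arr32 = np.ascontiguousarray(arr, dtype=np.float32)
    return (arr32.view(np.uint32) >> 16).astype(np.uint16)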
@@ -68,14 +69,27 @@ class TestScatterNdAddSimpleOp(OpTest):
     def setUp(self):
         self.op_type = "scatter_nd_add"
         self.python_api = paddle.scatter_nd_add
-        ref_np = np.random.random([100]).astype("float64")
+        self._set_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        ref_np = np.random.random([100]).astype(target_dtype)
         index_np = np.random.randint(0, 100, [100, 1]).astype("int32")
-        updates_np = np.random.random([100]).astype("float64")
+        updates_np = np.random.random([100]).astype(target_dtype)
         expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            expect_np = convert_float_to_uint16(expect_np)
         self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
         self.outputs = {'Out': expect_np}

+    def _set_dtype(self):
+        self.dtype = np.float64
+
     def test_check_output(self):
         self.check_output(check_eager=True)
@@ -83,6 +97,41 @@ class TestScatterNdAddSimpleOp(OpTest):
         self.check_grad(['X', 'Updates'], 'Out', check_eager=True)

+class TestScatterNdAddSimpleFP16Op(TestScatterNdAddSimpleOp):
+    """
+    A simple example
+    """
+
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterNdAddSimpleBF16Op(TestScatterNdAddSimpleOp):
+    """
+    A simple example
+    """
+
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=True)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Updates'], 'Out', check_eager=True
+            )
+
+
 class TestScatterNdAddWithEmptyIndex(OpTest):
     """
     Index has empty element
@@ -91,15 +140,30 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
     def setUp(self):
         self.op_type = "scatter_nd_add"
         self.python_api = paddle.scatter_nd_add
-        ref_np = np.random.random((10, 10)).astype("float64")
+        self._set_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        ref_np = np.random.random((10, 10)).astype(target_dtype)
         index_np = np.array([[], []]).astype("int32")
-        updates_np = np.random.random((2, 10, 10)).astype("float64")
+        updates_np = np.random.random((2, 10, 10)).astype(target_dtype)
         expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            expect_np = convert_float_to_uint16(expect_np)
         self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
         self.outputs = {'Out': expect_np}

+    def _set_dtype(self):
+        self.dtype = np.float64
+
     def test_check_output(self):
         self.check_output(check_eager=True)
@@ -107,6 +171,41 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
         self.check_grad(['X', 'Updates'], 'Out', check_eager=True)

+class TestScatterNdAddWithEmptyIndexFP16(TestScatterNdAddWithEmptyIndex):
+    """
+    Index has empty element
+    """
+
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterNdAddWithEmptyIndexBF16(TestScatterNdAddWithEmptyIndex):
+    """
+    Index has empty element
+    """
+
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=True)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Updates'], 'Out', check_eager=True
+            )
+
+
 class TestScatterNdAddWithHighRankSame(OpTest):
     """
     Both Index and X have high rank, and Rank(Index) = Rank(X)
@@ -115,18 +214,33 @@ class TestScatterNdAddWithHighRankSame(OpTest):
     def setUp(self):
         self.op_type = "scatter_nd_add"
         self.python_api = paddle.scatter_nd_add
+        self._set_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
         shape = (3, 2, 2, 1, 10)
-        ref_np = np.random.rand(*shape).astype("float64")
+        ref_np = np.random.rand(*shape).astype(target_dtype)
         index_np = np.vstack(
             [np.random.randint(0, s, size=100) for s in shape]
         ).T.astype("int32")
         update_shape = judge_update_shape(ref_np, index_np)
-        updates_np = np.random.rand(*update_shape).astype("float64")
+        updates_np = np.random.rand(*update_shape).astype(target_dtype)
         expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            expect_np = convert_float_to_uint16(expect_np)
         self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
         self.outputs = {'Out': expect_np}

+    def _set_dtype(self):
+        self.dtype = np.float64
+
     def test_check_output(self):
         self.check_output(check_eager=True)
@@ -134,6 +248,41 @@ class TestScatterNdAddWithHighRankSame(OpTest):
         self.check_grad(['X', 'Updates'], 'Out', check_eager=True)

+class TestScatterNdAddWithHighRankSameFP16(TestScatterNdAddWithHighRankSame):
+    """
+    Both Index and X have high rank, and Rank(Index) = Rank(X)
+    """
+
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterNdAddWithHighRankSameBF16(TestScatterNdAddWithHighRankSame):
+    """
+    Both Index and X have high rank, and Rank(Index) = Rank(X)
+    """
+
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=True)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Updates'], 'Out', check_eager=True
+            )
+
+
 class TestScatterNdAddWithHighRankDiff(OpTest):
     """
     Both Index and X have high rank, and Rank(Index) < Rank(X)
......
@@ -16,7 +16,7 @@ import os
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid as fluid
@@ -28,14 +28,23 @@ class TestScatterOp(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
-        ref_np = np.ones((3, 50)).astype("float32")
+        self._set_dtype()
+        target_dtype = "float16" if self.dtype == np.float16 else "float32"
+        ref_np = np.ones((3, 50)).astype(target_dtype)
         index_np = np.array([1, 2]).astype("int32")
-        updates_np = np.random.random((2, 50)).astype("float32")
+        updates_np = np.random.random((2, 50)).astype(target_dtype)
         output_np = np.copy(ref_np)
         output_np[index_np] = updates_np
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            output_np = convert_float_to_uint16(output_np)
         self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}

+    def _set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -43,19 +52,55 @@ class TestScatterOp(OpTest):
         self.check_grad(["X", "Updates"], "Out", check_eager=False)

+class TestScatterFP16Op(TestScatterOp):
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterBF16Op(TestScatterOp):
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=False)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Updates'], 'Out', check_eager=False
+            )
+
+
 class TestScatterOp0(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
-        ref_np = np.ones((3, 3)).astype("float32")
+        self._set_dtype()
+        target_dtype = "float16" if self.dtype == np.float16 else "float32"
+        ref_np = np.ones((3, 3)).astype(target_dtype)
         index_np = np.array([1, 2]).astype("int32")
-        updates_np = np.random.random((2, 3)).astype("float32")
+        updates_np = np.random.random((2, 3)).astype(target_dtype)
         output_np = np.copy(ref_np)
         output_np[index_np] = updates_np
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            output_np = convert_float_to_uint16(output_np)
         self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.attrs = {'overwrite': True}
         self.outputs = {'Out': output_np}

+    def _set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -63,22 +108,58 @@ class TestScatterOp0(OpTest):
         self.check_grad(["X", "Updates"], "Out", check_eager=False)

+class TestScatterFP16Op0(TestScatterOp0):
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterBF16Op0(TestScatterOp0):
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=False)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Updates'], 'Out', check_eager=False
+            )
+
+
 class TestScatterOp1(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
-        ref_np = np.ones((3, 3)).astype("float32")
-        zeros_np = np.zeros([2, 3]).astype('float32')
+        self._set_dtype()
+        target_dtype = "float16" if self.dtype == np.float16 else "float32"
+        ref_np = np.ones((3, 3)).astype(target_dtype)
+        zeros_np = np.zeros([2, 3]).astype(target_dtype)
         index_np = np.array([1, 1]).astype("int32")
-        updates_np = np.random.random((2, 3)).astype("float32")
+        updates_np = np.random.random((2, 3)).astype(target_dtype)
         output_np = np.copy(ref_np)
         output_np[index_np] = zeros_np
         for i in range(0, len(index_np)):
             output_np[index_np[i]] += updates_np[i]
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            output_np = convert_float_to_uint16(output_np)
         self.attrs = {'overwrite': False}
         self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}

+    def _set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output(check_eager=False)
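For reference, the expectation built in TestScatterOp1 (and later TestScatterOp3) mirrors scatter's overwrite=False semantics with a duplicated index: the targeted rows are cleared once and then every matching update row is accumulated into them. A small worked numpy example, illustrative only:

import numpy as np

ref = np.ones((3, 3), dtype=np.float32)
index = np.array([1, 1])
updates = np.arange(6, dtype=np.float32).reshape(2, 3)

out = np.copy(ref)
out[index] = np.zeros((2, 3), dtype=np.float32)  # clear the targeted row once
for i in range(len(index)):
    out[index[i]] += updates[i]                  # then accumulate both updates
# Row 1 is now updates[0] + updates[1] = [3., 5., 7.]; rows 0 and 2 stay all ones.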
@@ -86,6 +167,33 @@ class TestScatterOp1(OpTest):
         self.check_grad(["X", "Updates"], "Out", check_eager=False)

+class TestScatterFP16Op1(TestScatterOp1):
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterBF16Op1(TestScatterOp1):
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=False)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Updates'], 'Out', check_eager=False
+            )
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
@@ -93,14 +201,23 @@ class TestScatterOp2(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
-        ref_np = np.ones((3, 3)).astype("float32")
+        self._set_dtype()
+        target_dtype = "float16" if self.dtype == np.float16 else "float32"
+        ref_np = np.ones((3, 3)).astype(target_dtype)
         index_np = np.array([1, 2]).astype("int32")
-        updates_np = np.random.random((2, 3)).astype("float32")
+        updates_np = np.random.random((2, 3)).astype(target_dtype)
         output_np = np.copy(ref_np)
         output_np[index_np] = updates_np
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            output_np = convert_float_to_uint16(output_np)
         self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}

+    def _set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
@@ -114,6 +231,24 @@
             )

+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestScatterFP16Op2(TestScatterOp2):
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterBF16Op2(TestScatterOp2):
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
@@ -121,18 +256,27 @@ class TestScatterOp3(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
-        ref_np = np.ones((3, 3)).astype("float32")
-        zeros_np = np.zeros([2, 3]).astype('float32')
+        self._set_dtype()
+        target_dtype = "float16" if self.dtype == np.float16 else "float32"
+        ref_np = np.ones((3, 3)).astype(target_dtype)
+        zeros_np = np.zeros([2, 3]).astype(target_dtype)
         index_np = np.array([1, 1]).astype("int32")
-        updates_np = np.random.random((2, 3)).astype("float32")
+        updates_np = np.random.random((2, 3)).astype(target_dtype)
         output_np = np.copy(ref_np)
         output_np[index_np] = zeros_np
         for i in range(0, len(index_np)):
             output_np[index_np[i]] += updates_np[i]
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            output_np = convert_float_to_uint16(output_np)
         self.attrs = {'overwrite': False}
         self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}

+    def _set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
@@ -146,18 +290,45 @@
             )

+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestScatterFP16Op3(TestScatterOp3):
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterBF16Op3(TestScatterOp3):
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+
 class TestScatterOp4(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
-        ref_np = np.ones((3, 3)).astype("float32")
+        self._set_dtype()
+        target_dtype = "float16" if self.dtype == np.float16 else "float32"
+        ref_np = np.ones((3, 3)).astype(target_dtype)
         index_np = np.array([1, 2]).astype("int64")
-        updates_np = np.random.random((2, 3)).astype("float32")
+        updates_np = np.random.random((2, 3)).astype(target_dtype)
         output_np = np.copy(ref_np)
         output_np[index_np] = updates_np
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            output_np = convert_float_to_uint16(output_np)
         self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}

+    def _set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -165,6 +336,33 @@ class TestScatterOp4(OpTest):
         self.check_grad(['X', 'Updates'], 'Out', check_eager=False)

+class TestScatterFP16Op4(TestScatterOp4):
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterBF16Op4(TestScatterOp4):
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=False)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Updates'], 'Out', check_eager=False
+            )
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
@@ -172,14 +370,23 @@ class TestScatterOp5(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
-        ref_np = np.ones((3, 3)).astype("float32")
+        self._set_dtype()
+        target_dtype = "float16" if self.dtype == np.float16 else "float32"
+        ref_np = np.ones((3, 3)).astype(target_dtype)
         index_np = np.array([1, 2]).astype("int64")
-        updates_np = np.random.random((2, 3)).astype("float32")
+        updates_np = np.random.random((2, 3)).astype(target_dtype)
         output_np = np.copy(ref_np)
         output_np[index_np] = updates_np
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            output_np = convert_float_to_uint16(output_np)
         self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}

+    def _set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
@@ -193,18 +400,45 @@ class TestScatterOp5(OpTest):
             )

+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestScatterFP16Op5(TestScatterOp5):
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterBF16Op5(TestScatterOp5):
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+
 class TestScatterOp6(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
-        ref_np = np.ones((3, 50)).astype("float32")
+        self._set_dtype()
+        target_dtype = "float16" if self.dtype == np.float16 else "float32"
+        ref_np = np.ones((3, 50)).astype(target_dtype)
         index_np = np.array([[1], [2]]).astype("int32")
-        updates_np = np.random.random((2, 50)).astype("float32")
+        updates_np = np.random.random((2, 50)).astype(target_dtype)
         output_np = np.copy(ref_np)
         output_np[np.array([1, 2]).astype("int32")] = updates_np
+        if self.dtype == np.uint16:
+            ref_np = convert_float_to_uint16(ref_np)
+            updates_np = convert_float_to_uint16(updates_np)
+            output_np = convert_float_to_uint16(output_np)
         self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}

+    def _set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -212,6 +446,33 @@ class TestScatterOp6(OpTest):
         self.check_grad(["X", "Updates"], "Out", check_eager=False)

+class TestScatterFP16Op6(TestScatterOp6):
+    def _set_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestScatterBF16Op6(TestScatterOp6):
+    def _set_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=False)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Updates'], 'Out', check_eager=False
+            )
+
+
 class TestScatterAPI(unittest.TestCase):
     def setUp(self):
         self.places = [fluid.CPUPlace()]
......
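A note on running the new cases locally: assuming the file keeps Paddle's usual unittest layout and is importable as test_scatter_op (an assumption, the path is not shown in this diff), a single new class can be run on its own, for example:

# Hypothetical runner snippet; the class name is taken from the diff, but the
# module name and import path assume you are inside the unittests directory.
import unittest
from test_scatter_op import TestScatterBF16Op

suite = unittest.TestLoader().loadTestsFromTestCase(TestScatterBF16Op)
unittest.TextTestRunner(verbosity=2).run(suite)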