Unverified commit 1eb30775, authored by chenxujun, committed by GitHub

Add index_add, index_sample, put_along_axis, take_along_axis tests (#52572)

Parent afc2c598
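For context, a hedged smoke-test sketch (not part of this diff) of what the change enables at the Python level; it assumes a CUDA build where bfloat16 is supported and that paddle.cast accepts 'bfloat16':

import paddle

# Build bf16 tensors by casting; any other bf16 construction works as well.
x = paddle.cast(paddle.rand([4, 3]), 'bfloat16')
index = paddle.to_tensor([0, 2], dtype='int64')
value = paddle.cast(paddle.rand([2, 3]), 'bfloat16')

added = paddle.index_add(x, index, axis=0, value=value)
sampled = paddle.index_sample(x, paddle.to_tensor([[0, 1], [2, 0], [1, 1], [0, 2]], dtype='int32'))
idx = paddle.zeros([4, 1], dtype='int64')
taken = paddle.take_along_axis(x, idx, axis=1)
put = paddle.put_along_axis(x, idx, paddle.cast(paddle.full([4, 1], 9.0), 'bfloat16'), axis=1)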
@@ -21,12 +21,14 @@ limitations under the License. */
namespace phi {
namespace funcs {
-#define Instantiate_Template_Function(func)                                  \
-  Instantiate_Template_Function_index_t(                                     \
-      func, int) Instantiate_Template_Function_index_t(func, float)          \
-      Instantiate_Template_Function_index_t(func, double)                    \
-          Instantiate_Template_Function_index_t(func, int64_t)               \
-              Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
+#define Instantiate_Template_Function(func)                                  \
+  Instantiate_Template_Function_index_t(                                     \
+      func, int) Instantiate_Template_Function_index_t(func, float)          \
+      Instantiate_Template_Function_index_t(                                 \
+          func, double) Instantiate_Template_Function_index_t(func, int64_t) \
+          Instantiate_Template_Function_index_t(func, phi::dtype::float16)   \
+              Instantiate_Template_Function_index_t(func,                    \
+                                                    phi::dtype::bfloat16)    \
Instantiate_Template_Function_index_t(func, unsigned char)
#define Instantiate_Template_Function_index_t(func, tensor_t) \
......
@@ -105,5 +105,6 @@ PD_REGISTER_KERNEL(index_add_grad,
float,
double,
phi::dtype::float16,
+phi::dtype::bfloat16,
int,
int64_t) {}
@@ -123,5 +123,6 @@ PD_REGISTER_KERNEL(index_add,
float,
double,
phi::dtype::float16,
+phi::dtype::bfloat16,
int,
int64_t) {}
@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
double,
int64_t,
int,
-phi::dtype::float16) {}
+phi::dtype::float16,
+phi::dtype::bfloat16) {}
@@ -82,4 +82,5 @@ PD_REGISTER_KERNEL(put_along_axis,
double,
int64_t,
int,
-phi::dtype::float16) {}
+phi::dtype::float16,
+phi::dtype::bfloat16) {}
@@ -68,4 +68,5 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
double,
int64_t,
int,
-phi::dtype::float16) {}
+phi::dtype::float16,
+phi::dtype::bfloat16) {}
@@ -54,4 +54,5 @@ PD_REGISTER_KERNEL(take_along_axis,
double,
int64_t,
int,
-phi::dtype::float16) {}
+phi::dtype::float16,
+phi::dtype::bfloat16) {}
@@ -15,10 +15,10 @@
import unittest
import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
import paddle
-from paddle.fluid import Program
+from paddle.fluid import Program, core
def compute_index_add_ref(
@@ -99,6 +99,69 @@ class TestIndexAddOp(OpTest):
self.check_grad(['X', 'AddValue'], 'Out')
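An aside for orientation (not part of the diff): compute_index_add_ref, used by the tests above and below, mirrors the semantics of numpy's unbuffered np.add.at, so a compact reference could look like this sketch:

import numpy as np

def index_add_ref(x, index, axis, value):
    # out[..., index[i], ...] += value[..., i, ...]; np.add.at is unbuffered,
    # so repeated indices accumulate instead of overwriting each other.
    out = x.copy()
    sl = [slice(None)] * x.ndim
    sl[axis] = index
    np.add.at(out, tuple(sl), value)
    return out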
class TestIndexAddFP16Op(TestIndexAddOp):
def init_dtype_type(self):
self.axis = 0
self.x_type = np.float16
self.index_type = np.int64
self.x_shape = (101, 3)
self.index_size = 3
self.add_value_shape = (3, 3)
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestIndexAddBF16Op(OpTest):
def setUp(self):
self.python_api = raw_index_add
self.op_type = "index_add"
self.init_dtype_type()
index_np = np.random.randint(
low=0, high=self.x_shape[self.axis], size=self.index_size
)
x_np = np.random.random(self.x_shape).astype(self.x_type)
add_value_np = np.random.random(self.add_value_shape).astype(
self.x_type
)
self.inputs = {
'X': convert_float_to_uint16(x_np),
'Index': index_np,
'AddValue': convert_float_to_uint16(add_value_np),
}
self.attrs = {'axis': self.axis}
out = compute_index_add_ref(
self.axis,
self.x_shape,
x_np,
self.add_value_shape,
add_value_np,
self.index_size,
index_np,
)
self.outputs = {'Out': convert_float_to_uint16(out)}
self.place = core.CUDAPlace(0)
def init_dtype_type(self):
self.axis = 0
self.x_type = np.float32
self.index_type = np.int64
self.x_shape = (101, 3)
self.index_size = 3
self.add_value_shape = (3, 3)
self.dtype = np.uint16
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X', 'AddValue'], 'Out')
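convert_float_to_uint16, imported above, exists because numpy has no bfloat16 dtype: bf16 data is carried around as raw uint16 bit patterns. A minimal sketch of the idea, assuming round-to-nearest-even and ignoring NaN/inf corner cases (float_to_bf16_bits is a hypothetical name, not the real helper):

import numpy as np

def float_to_bf16_bits(x):
    # bfloat16 keeps the upper 16 bits of an IEEE-754 float32; round the
    # discarded lower half to nearest-even before truncating.
    u = np.asarray(x, dtype=np.float32).view(np.uint32)
    bias = ((u >> 16) & 1) + 0x7FFF
    return ((u + bias) >> 16).astype(np.uint16)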
class TestIndexAddAPI(unittest.TestCase):
def setUp(self):
self.setType()
......
@@ -15,10 +15,11 @@
import unittest
import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
+from paddle.fluid import core
class TestIndexSampleOp(OpTest):
@@ -121,6 +122,49 @@ class TestCase6(TestIndexSampleOp):
self.index_type = "int64"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestIndexSampleBF16Op(OpTest):
def setUp(self):
self.op_type = "index_sample"
self.python_api = paddle.index_sample
self.config()
xnp = np.random.random(self.x_shape).astype(self.x_type)
indexnp = np.random.randint(
low=0, high=self.x_shape[1], size=self.index_shape
).astype(self.index_type)
self.inputs = {'X': xnp, 'Index': indexnp}
index_array = []
for i in range(self.index_shape[0]):
for j in indexnp[i]:
index_array.append(xnp[i, j])
index_array = np.array(index_array).astype(self.x_type)
out = np.reshape(index_array, self.index_shape)
self.outputs = {'Out': out}
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def config(self):
"""
For multi-dimensional input
"""
self.x_shape = (10, 20)
self.x_type = "float32"
self.dtype = np.uint16
self.index_shape = (10, 10)
self.index_type = "int32"
class TestIndexSampleShape(unittest.TestCase):
def test_shape(self):
paddle.enable_static()
......
@@ -16,7 +16,7 @@ import copy
import unittest
import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.framework import core
@@ -28,19 +28,18 @@ class TestPutAlongAxisOp(OpTest):
def setUp(self):
self.init_data()
self.reduce_op = "assign"
-        self.dtype = 'float64'
self.op_type = "put_along_axis"
self.python_api = paddle.tensor.put_along_axis
self.xnp = np.random.random(self.x_shape).astype(self.x_type)
-        # numpy put_along_axis is an inplace opearion.
+        # numpy put_along_axis is an inplace operation.
self.xnp_result = copy.deepcopy(self.xnp)
np.put_along_axis(self.xnp_result, self.index, self.value, self.axis)
self.target = self.xnp_result
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
-        self.braodcast_shape = tuple(broadcast_shape_list)
-        self.index_broadcast = np.broadcast_to(self.index, self.braodcast_shape)
-        self.value_broadcast = np.broadcast_to(self.value, self.braodcast_shape)
+        self.broadcast_shape = tuple(broadcast_shape_list)
+        self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
+        self.value_broadcast = np.broadcast_to(self.value, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
@@ -56,6 +55,7 @@ class TestPutAlongAxisOp(OpTest):
self.check_grad(["Input", "Value"], "Result")
def init_data(self):
+        self.dtype = 'float64'
self.x_type = "float64"
self.x_shape = (10, 10, 10)
self.value_type = "float64"
@@ -66,6 +66,71 @@ class TestPutAlongAxisOp(OpTest):
self.axis_type = "int64"
class TestPutAlongAxisFP16Op(TestPutAlongAxisOp):
def init_data(self):
self.dtype = np.float16
self.x_type = "float16"
self.x_shape = (10, 10, 10)
self.value_type = "float16"
self.value = np.array([99]).astype(self.value_type)
self.index_type = "int32"
self.index = np.array([[[0]]]).astype(self.index_type)
self.axis = 1
self.axis_type = "int64"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestPutAlongAxisBF16Op(OpTest):
def setUp(self):
self.init_data()
self.reduce_op = "assign"
self.op_type = "put_along_axis"
self.python_api = paddle.tensor.put_along_axis
self.xnp = np.random.random(self.x_shape).astype(self.x_type)
# numpy put_along_axis is an inplace operation.
self.xnp_result = copy.deepcopy(self.xnp)
np.put_along_axis(self.xnp_result, self.index, self.value, self.axis)
self.target = self.xnp_result
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
self.broadcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
self.value_broadcast = np.broadcast_to(self.value, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
'Value': self.value_broadcast,
}
self.attrs = {'Axis': self.axis, 'Reduce': self.reduce_op}
self.outputs = {'Result': self.target}
self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
self.inputs['Value'] = convert_float_to_uint16(self.inputs['Value'])
self.outputs['Result'] = convert_float_to_uint16(self.outputs['Result'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ["Input", "Value"], "Result")
def init_data(self):
self.dtype = np.uint16
self.x_type = "float32"
self.x_shape = (10, 10, 10)
self.value_type = "float32"
self.value = np.array([99]).astype(self.value_type)
self.index_type = "int32"
self.index = np.array([[[0]]]).astype(self.index_type)
self.axis = 1
self.axis_type = "int64"
class TestPutAlongAxisAPI(unittest.TestCase):
def setUp(self):
np.random.seed(0)
......
@@ -15,7 +15,7 @@
import unittest
import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.framework import core
@@ -32,8 +32,8 @@ class TestTakeAlongAxisOp(OpTest):
self.target = np.take_along_axis(self.xnp, self.index, self.axis)
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
-        self.braodcast_shape = tuple(broadcast_shape_list)
-        self.index_broadcast = np.broadcast_to(self.index, self.braodcast_shape)
+        self.broadcast_shape = tuple(broadcast_shape_list)
+        self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
@@ -58,6 +58,64 @@ class TestTakeAlongAxisOp(OpTest):
self.axis_type = "int64"
class TestTakeAlongAxisFP16Op(TestTakeAlongAxisOp):
def init_data(self):
self.dtype = np.float16
self.x_type = "float16"
self.x_shape = (5, 5, 5)
self.index_type = "int32"
self.index = np.array([[[1]], [[1]], [[2]], [[4]], [[3]]]).astype(
self.index_type
)
self.axis = 2
self.axis_type = "int64"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestTakeAlongAxisBF16Op(OpTest):
def setUp(self):
self.init_data()
self.op_type = "take_along_axis"
self.python_api = paddle.tensor.take_along_axis
self.xnp = np.random.random(self.x_shape).astype(self.x_type)
self.target = np.take_along_axis(self.xnp, self.index, self.axis)
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
self.broadcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
}
self.attrs = {'Axis': self.axis}
self.outputs = {'Result': self.target}
self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
self.outputs['Result'] = convert_float_to_uint16(self.outputs['Result'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['Input'], 'Result')
def init_data(self):
self.dtype = np.uint16
self.x_type = "float32"
self.x_shape = (5, 5, 5)
self.index_type = "int32"
self.index = np.array([[[1]], [[1]], [[2]], [[4]], [[3]]]).astype(
self.index_type
)
self.axis = 2
self.axis_type = "int64"
class TestCase1(TestTakeAlongAxisOp):
def init_data(self):
self.x_type = "float64"
......
@@ -4540,7 +4540,15 @@ def take_along_axis(arr, indices, axis):
check_variable_and_dtype(
arr,
'x',
-        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
+        [
+            'float16',
+            'float32',
+            'float64',
+            'int32',
+            'int64',
+            'uint8',
+            'uint16',
+        ],
'take_along_axis',
)
check_variable_and_dtype(
@@ -4612,7 +4620,15 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
check_variable_and_dtype(
arr,
'x',
-        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
+        [
+            'float16',
+            'float32',
+            'float64',
+            'int32',
+            'int64',
+            'uint8',
+            'uint16',
+        ],
'put_along_axis',
)
check_variable_and_dtype(
@@ -4694,7 +4710,7 @@ def index_add(x, index, axis, value, name=None):
check_variable_and_dtype(
x,
'x',
-        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'paddle.tensor.manipulation.index_add',
)
check_variable_and_dtype(
@@ -4706,7 +4722,7 @@ def index_add(x, index, axis, value, name=None):
check_variable_and_dtype(
value,
'add_value',
-        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'paddle.tensor.manipulation.index_add',
)
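A note on the 'uint16' entries added to these whitelists: numpy has no native bfloat16, so bf16 tensors surface to Python as uint16 bit patterns, and the static-graph dtype check must accept them. An illustrative sketch of that representation (plain truncation; the test helper also rounds):

import numpy as np

bits = np.float32([1.5, -2.0]).view(np.uint32)
bf16 = (bits >> 16).astype(np.uint16)  # top half of each float32
print(bf16.dtype)  # uint16, the dtype the checks above now allow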
......