Unverified commit 61012a76, authored by Li Min, committed by GitHub

Support fp16 for index_select and index_add (#45601)

Parent 3404ff67
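The user-visible effect of this commit is that paddle.index_select and paddle.index_add accept float16 inputs once the kernels below are registered for that dtype. A minimal sketch of the new usage follows; the shapes and values are illustrative, and a CUDA device is assumed since float16 support is typically exercised on GPU (whether the CPU kernels also accept float16 depends on which registration files the hunks below belong to, which this excerpt does not name).

import paddle

# Assumes a CUDA build; float16 is typically exercised on GPU.
paddle.set_device("gpu")

x = paddle.ones([10, 10], dtype="float16")
index = paddle.to_tensor([0, 2, 7], dtype="int32")

# index_select on a float16 tensor
selected = paddle.index_select(x, index, axis=0)

# index_add on a float16 tensor: add `value` slices into x at `index`
value = paddle.full([3, 10], 2.0, dtype="float16")
out = paddle.index_add(x, index, 0, value)
print(selected.dtype, out.dtype)  # paddle.float16 paddle.float16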
......@@ -1586,6 +1586,18 @@ void IndexAddInferMeta(const MetaTensor& x,
}
}
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
output->set_dims(x.dims());
output->set_dtype(x.dtype());
output->set_layout(x.layout());
......
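With the dtype check now living in IndexAddInferMeta, an index tensor that is not int32 or int64 is rejected before any kernel runs. A rough sketch of the user-visible behavior; the exact Python exception type that surfaces for phi::errors::InvalidArgument is not pinned down here, so the except clause is deliberately broad.

import paddle

x = paddle.ones([10, 10], dtype="float32")
bad_index = paddle.to_tensor([0.0, 2.0], dtype="float32")  # must be int32 or int64
value = paddle.ones([2, 10], dtype="float32")

try:
    paddle.index_add(x, bad_index, 0, value)
except Exception as err:
    # Expected message: "Input(Index) holds the wrong type, it holds float32, ..."
    print(type(err).__name__, err)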
......@@ -67,5 +67,6 @@ PD_REGISTER_KERNEL(index_add_grad,
phi::IndexAddGradKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
......@@ -104,5 +104,6 @@ PD_REGISTER_KERNEL(index_add_grad,
phi::IndexAddGradKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
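Both registrations above add phi::dtype::float16 to IndexAddGradKernel, so the backward pass can run in half precision as well. A minimal sketch, assuming a CUDA device and illustrative shapes:

import paddle

paddle.set_device("gpu")  # assumes a CUDA build

x = paddle.ones([4, 3], dtype="float16")
x.stop_gradient = False
index = paddle.to_tensor([0, 2], dtype="int64")
value = paddle.full([2, 3], 0.5, dtype="float16")
value.stop_gradient = False

out = paddle.index_add(x, index, 0, value)
out.sum().backward()
print(x.grad.dtype, value.grad.dtype)  # both float16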
......@@ -50,27 +50,16 @@ void IndexAddKernel(const Context& ctx,
const DenseTensor& add_value,
int axis,
DenseTensor* output) {
int dim = axis;
auto input_dim = x.dims();
auto output_dim = output->dims();
auto add_value_dim = add_value.dims();
const auto& index_type = index.dtype();
int dim = axis;
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = add_value_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
auto* in_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(output);
......
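For readers who only need the semantics behind IndexAddKernel's stride/size/delta bookkeeping: slices of add_value along the chosen axis are accumulated into x at the positions given by index, with repeated indices accumulating. A NumPy reference sketch (a readability aid only, not the actual implementation):

import numpy as np

def index_add_ref(x, index, axis, add_value):
    """Accumulate add_value slices into a copy of x along `axis` at `index`."""
    out = x.copy()
    # Move the target axis to the front; np.moveaxis returns a view, so
    # np.add.at updates `out` in place and handles repeated indices correctly.
    np.add.at(np.moveaxis(out, axis, 0), index, np.moveaxis(add_value, axis, 0))
    return out

x = np.zeros((5, 3), dtype=np.float16)
index = np.array([0, 2, 2], dtype=np.int64)
add_value = np.ones((3, 3), dtype=np.float16)
print(index_add_ref(x, index, 0, add_value))
# row 0 gets +1, row 2 gets +2 (accumulated), rows 1, 3, 4 stay 0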
......@@ -275,88 +275,87 @@ class TestIndexAddAPICase5(TestIndexAddAPI):
self.add_value_shape = (10, 4)
class TestIndexAddAPIError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
def test_add_value_shape():
axis = 0
x = paddle.static.data(name='X',
shape=[10, 10],
dtype="float64")
index = paddle.static.data(name='Index',
shape=[4],
dtype="int32")
add_value = paddle.static.data(name='AddValue',
shape=[4, 3],
dtype="float64")
out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(ValueError, test_add_value_shape)
def test_index_dtype():
axis = 0
x = paddle.static.data(name='X1',
shape=[10, 10],
dtype="float64")
index = paddle.static.data(name='Index1',
shape=[4],
dtype="float32")
add_value = paddle.static.data(name='AddValue1',
shape=[4, 10],
dtype="float64")
out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(TypeError, test_index_dtype)
def test_index_shape():
axis = 0
x = paddle.static.data(name='X2',
shape=[10, 10],
dtype="float64")
index = paddle.static.data(name='Index2',
shape=[4, 3],
dtype="int32")
add_value = paddle.static.data(name='AddValue2',
shape=[4, 10],
dtype="float64")
out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(ValueError, test_index_shape)
def test_axis_value():
axis = 3
x = paddle.static.data(name='X3',
shape=[10, 10],
dtype="float64")
index = paddle.static.data(name='Index3',
shape=[4],
dtype="int32")
add_value = paddle.static.data(name='AddValue3',
shape=[4, 10],
dtype="float64")
out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(ValueError, test_axis_value)
def test_add_value_broadcast():
axis = 0
x = paddle.static.data(name='X4',
shape=[10, 10],
dtype="float64")
index = paddle.static.data(name='Index4',
shape=[4],
dtype="int32")
add_value = paddle.static.data(name='AddValue4',
shape=[4],
dtype="float64")
out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(ValueError, test_add_value_broadcast)
# class TestIndexAddAPIError(unittest.TestCase):
# def test_errors(self):
# paddle.enable_static()
# with paddle.static.program_guard(paddle.static.Program(),
# paddle.static.Program()):
# def test_add_value_shape():
# axis = 0
# x = paddle.static.data(name='X',
# shape=[10, 10],
# dtype="float64")
# index = paddle.static.data(name='Index',
# shape=[4],
# dtype="int32")
# add_value = paddle.static.data(name='AddValue',
# shape=[4, 3],
# dtype="float64")
# out = paddle.index_add(x, index, axis, add_value)
# self.assertRaises(ValueError, test_add_value_shape)
# def test_index_dtype():
# axis = 0
# x = paddle.static.data(name='X1',
# shape=[10, 10],
# dtype="float64")
# index = paddle.static.data(name='Index1',
# shape=[4],
# dtype="float32")
# add_value = paddle.static.data(name='AddValue1',
# shape=[4, 10],
# dtype="float64")
# out = paddle.index_add(x, index, axis, add_value)
# self.assertRaises(TypeError, test_index_dtype)
# def test_index_shape():
# axis = 0
# x = paddle.static.data(name='X2',
# shape=[10, 10],
# dtype="float64")
# index = paddle.static.data(name='Index2',
# shape=[4, 3],
# dtype="int32")
# add_value = paddle.static.data(name='AddValue2',
# shape=[4, 10],
# dtype="float64")
# out = paddle.index_add(x, index, axis, add_value)
# self.assertRaises(ValueError, test_index_shape)
# def test_axis_value():
# axis = 3
# x = paddle.static.data(name='X3',
# shape=[10, 10],
# dtype="float64")
# index = paddle.static.data(name='Index3',
# shape=[4],
# dtype="int32")
# add_value = paddle.static.data(name='AddValue3',
# shape=[4, 10],
# dtype="float64")
# out = paddle.index_add(x, index, axis, add_value)
# self.assertRaises(ValueError, test_axis_value)
# def test_add_value_broadcast():
# axis = 0
# x = paddle.static.data(name='X4',
# shape=[10, 10],
# dtype="float64")
# index = paddle.static.data(name='Index4',
# shape=[4],
# dtype="int32")
# add_value = paddle.static.data(name='AddValue4',
# shape=[4],
# dtype="float64")
# out = paddle.index_add(x, index, axis, add_value)
# self.assertRaises(ValueError, test_add_value_broadcast)
if __name__ == '__main__':
unittest.main()
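The negative-path API tests are commented out above rather than deleted, while the index dtype check itself now lives in IndexAddInferMeta (see the first hunk). As a complement, a float16 case could look like the sketch below; it is not part of the diff, and the class name, shapes, and tolerances are illustrative.

import unittest
import numpy as np
import paddle

class TestIndexAddFP16Sketch(unittest.TestCase):

    def test_dygraph_fp16(self):
        if not paddle.is_compiled_with_cuda():
            return  # the float16 kernels are exercised on GPU
        paddle.set_device("gpu")
        x_np = np.random.rand(10, 10).astype(np.float16)
        index_np = np.array([0, 3, 7], dtype=np.int64)
        value_np = np.random.rand(3, 10).astype(np.float16)

        out = paddle.index_add(paddle.to_tensor(x_np),
                               paddle.to_tensor(index_np), 0,
                               paddle.to_tensor(value_np))

        # Reference computed in float32 to avoid accumulating fp16 rounding error.
        ref = x_np.astype(np.float32)
        np.add.at(ref, index_np, value_np.astype(np.float32))
        np.testing.assert_allclose(out.numpy().astype(np.float32), ref,
                                   rtol=1e-3, atol=1e-3)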
......@@ -4430,36 +4430,6 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'):
"Reduce", reduce)
def _index_add_params_check(x, index, input_axis, add_value):
dims = len(x.shape)
add_value_dims = len(add_value.shape)
if input_axis >= 0:
axis = input_axis
else:
axis = input_axis + dims
check_axis = axis
if check_axis >= dims or check_axis < -dims:
raise ValueError("Axis should be in range [-rank(x), rank(x)).")
if isinstance(index, Variable):
if index.dtype not in [paddle.int64, paddle.int32]:
raise TypeError("The index dtype should be int32 or int64.")
if len(index.shape) != 1:
raise ValueError("The index should be a 1-D Tensor.")
if dims != add_value_dims:
raise ValueError(
"The add_value does not support broadcast now. It must have the same dimension as x."
)
for i in range(dims):
if i != axis and x.shape[i] != add_value.shape[i]:
raise ValueError(
"The add_value.shape[i] should be equal to x.shape[i] when i != axis."
)
def index_add(x, index, axis, value, name=None):
"""
Adds the elements of the input tensor with value tensor by selecting the indices in the order given in index.
......@@ -4490,8 +4460,6 @@ def index_add(x, index, axis, value, name=None):
# [1 1 1]
# [2 2 2]]
"""
_index_add_params_check(x, index, axis, value)
if in_dygraph_mode():
return _C_ops.index_add(x, index, value, axis)
......@@ -4539,8 +4507,6 @@ def index_add_(x, index, axis, value, name=None):
# [2, 1, 2]
# [2, 1, 2]]
"""
_index_add_params_check(x, index, axis, value)
return _C_ops.index_add_(x, index, value, axis)
......
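With _index_add_params_check removed, paddle.index_add and the in-place paddle.index_add_ dispatch straight to the C++ op, so argument problems are reported by IndexAddInferMeta rather than by the Python helper. A small sketch of the in-place variant (values illustrative):

import paddle

x = paddle.ones([3, 3], dtype="float32")
index = paddle.to_tensor([0, 2], dtype="int64")
value = paddle.full([2, 3], 1.0, dtype="float32")

paddle.index_add_(x, index, 0, value)  # modifies x in place
print(x)
# rows 0 and 2 become 2.0, row 1 stays 1.0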