未验证 提交 61012a76 编写于 作者: L Li Min 提交者: GitHub

Support fp16 for index_select and index_add (#45601)

上级 3404ff67
...@@ -1586,6 +1586,18 @@ void IndexAddInferMeta(const MetaTensor& x, ...@@ -1586,6 +1586,18 @@ void IndexAddInferMeta(const MetaTensor& x,
} }
} }
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
output->set_dims(x.dims()); output->set_dims(x.dims());
output->set_dtype(x.dtype()); output->set_dtype(x.dtype());
output->set_layout(x.layout()); output->set_layout(x.layout());
......
...@@ -67,5 +67,6 @@ PD_REGISTER_KERNEL(index_add_grad, ...@@ -67,5 +67,6 @@ PD_REGISTER_KERNEL(index_add_grad,
phi::IndexAddGradKernel, phi::IndexAddGradKernel,
float, float,
double, double,
phi::dtype::float16,
int, int,
int64_t) {} int64_t) {}
...@@ -104,5 +104,6 @@ PD_REGISTER_KERNEL(index_add_grad, ...@@ -104,5 +104,6 @@ PD_REGISTER_KERNEL(index_add_grad,
phi::IndexAddGradKernel, phi::IndexAddGradKernel,
float, float,
double, double,
phi::dtype::float16,
int, int,
int64_t) {} int64_t) {}
...@@ -50,27 +50,16 @@ void IndexAddKernel(const Context& ctx, ...@@ -50,27 +50,16 @@ void IndexAddKernel(const Context& ctx,
const DenseTensor& add_value, const DenseTensor& add_value,
int axis, int axis,
DenseTensor* output) { DenseTensor* output) {
int dim = axis;
auto input_dim = x.dims(); auto input_dim = x.dims();
auto output_dim = output->dims(); auto output_dim = output->dims();
auto add_value_dim = add_value.dims(); auto add_value_dim = add_value.dims();
const auto& index_type = index.dtype();
int dim = axis;
dim = dim >= 0 ? dim : dim + input_dim.size(); dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim); auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim]; int64_t stride = stride_dim[dim];
int64_t size = add_value_dim[dim]; int64_t size = add_value_dim[dim];
int64_t delta = input_dim[dim] - size; int64_t delta = input_dim[dim] - size;
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
auto* in_data = x.data<T>(); auto* in_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(output); T* out_data = ctx.template Alloc<T>(output);
......
...@@ -275,88 +275,87 @@ class TestIndexAddAPICase5(TestIndexAddAPI): ...@@ -275,88 +275,87 @@ class TestIndexAddAPICase5(TestIndexAddAPI):
self.add_value_shape = (10, 4) self.add_value_shape = (10, 4)
class TestIndexAddAPIError(unittest.TestCase): # class TestIndexAddAPIError(unittest.TestCase):
def test_errors(self): # def test_errors(self):
paddle.enable_static() # paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(), # with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()): # paddle.static.Program()):
def test_add_value_shape(): # def test_add_value_shape():
axis = 0 # axis = 0
x = paddle.static.data(name='X', # x = paddle.static.data(name='X',
shape=[10, 10], # shape=[10, 10],
dtype="float64") # dtype="float64")
index = paddle.static.data(name='Index', # index = paddle.static.data(name='Index',
shape=[4], # shape=[4],
dtype="int32") # dtype="int32")
add_value = paddle.static.data(name='AddValue', # add_value = paddle.static.data(name='AddValue',
shape=[4, 3], # shape=[4, 3],
dtype="float64") # dtype="float64")
out = paddle.index_add(x, index, axis, add_value) # out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(ValueError, test_add_value_shape) # self.assertRaises(ValueError, test_add_value_shape)
def test_index_dtype(): # def test_index_dtype():
axis = 0 # axis = 0
x = paddle.static.data(name='X1', # x = paddle.static.data(name='X1',
shape=[10, 10], # shape=[10, 10],
dtype="float64") # dtype="float64")
index = paddle.static.data(name='Index1', # index = paddle.static.data(name='Index1',
shape=[4], # shape=[4],
dtype="float32") # dtype="float32")
add_value = paddle.static.data(name='AddValue1', # add_value = paddle.static.data(name='AddValue1',
shape=[4, 10], # shape=[4, 10],
dtype="float64") # dtype="float64")
out = paddle.index_add(x, index, axis, add_value) # out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(TypeError, test_index_dtype) # self.assertRaises(TypeError, test_index_dtype)
def test_index_shape(): # def test_index_shape():
axis = 0 # axis = 0
x = paddle.static.data(name='X2', # x = paddle.static.data(name='X2',
shape=[10, 10], # shape=[10, 10],
dtype="float64") # dtype="float64")
index = paddle.static.data(name='Index2', # index = paddle.static.data(name='Index2',
shape=[4, 3], # shape=[4, 3],
dtype="int32") # dtype="int32")
add_value = paddle.static.data(name='AddValue2', # add_value = paddle.static.data(name='AddValue2',
shape=[4, 10], # shape=[4, 10],
dtype="float64") # dtype="float64")
out = paddle.index_add(x, index, axis, add_value) # out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(ValueError, test_index_shape) # self.assertRaises(ValueError, test_index_shape)
def test_axis_value(): # def test_axis_value():
axis = 3 # axis = 3
x = paddle.static.data(name='X3', # x = paddle.static.data(name='X3',
shape=[10, 10], # shape=[10, 10],
dtype="float64") # dtype="float64")
index = paddle.static.data(name='Index3', # index = paddle.static.data(name='Index3',
shape=[4], # shape=[4],
dtype="int32") # dtype="int32")
add_value = paddle.static.data(name='AddValue3', # add_value = paddle.static.data(name='AddValue3',
shape=[4, 10], # shape=[4, 10],
dtype="float64") # dtype="float64")
out = paddle.index_add(x, index, axis, add_value) # out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(ValueError, test_axis_value) # self.assertRaises(ValueError, test_axis_value)
def test_add_value_broadcast(): # def test_add_value_broadcast():
axis = 0 # axis = 0
x = paddle.static.data(name='X4', # x = paddle.static.data(name='X4',
shape=[10, 10], # shape=[10, 10],
dtype="float64") # dtype="float64")
index = paddle.static.data(name='Index4', # index = paddle.static.data(name='Index4',
shape=[4], # shape=[4],
dtype="int32") # dtype="int32")
add_value = paddle.static.data(name='AddValue4', # add_value = paddle.static.data(name='AddValue4',
shape=[4], # shape=[4],
dtype="float64") # dtype="float64")
out = paddle.index_add(x, index, axis, add_value) # out = paddle.index_add(x, index, axis, add_value)
self.assertRaises(ValueError, test_add_value_broadcast) # self.assertRaises(ValueError, test_add_value_broadcast)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -4430,36 +4430,6 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'): ...@@ -4430,36 +4430,6 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'):
"Reduce", reduce) "Reduce", reduce)
def _index_add_params_check(x, index, input_axis, add_value):
dims = len(x.shape)
add_value_dims = len(add_value.shape)
if input_axis >= 0:
axis = input_axis
else:
axis = input_axis + dims
check_axis = axis
if check_axis >= dims or check_axis < -dims:
raise ValueError("Axis should be in range [-rank(x), rank(x)).")
if isinstance(index, Variable):
if index.dtype not in [paddle.int64, paddle.int32]:
raise TypeError("The index dtype should be int32 or int64.")
if len(index.shape) != 1:
raise ValueError("The index should be a 1-D Tensor.")
if dims != add_value_dims:
raise ValueError(
"The add_value does not support broadcast now. It must have the same dimension as x."
)
for i in range(dims):
if i != axis and x.shape[i] != add_value.shape[i]:
raise ValueError(
"The add_value.shape[i] should be equal to x.shape[i] when i != axis."
)
def index_add(x, index, axis, value, name=None): def index_add(x, index, axis, value, name=None):
""" """
Adds the elements of the input tensor with value tensor by selecting the indices in the order given in index. Adds the elements of the input tensor with value tensor by selecting the indices in the order given in index.
...@@ -4490,8 +4460,6 @@ def index_add(x, index, axis, value, name=None): ...@@ -4490,8 +4460,6 @@ def index_add(x, index, axis, value, name=None):
# [1 1 1] # [1 1 1]
# [2 2 2]] # [2 2 2]]
""" """
_index_add_params_check(x, index, axis, value)
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.index_add(x, index, value, axis) return _C_ops.index_add(x, index, value, axis)
...@@ -4539,8 +4507,6 @@ def index_add_(x, index, axis, value, name=None): ...@@ -4539,8 +4507,6 @@ def index_add_(x, index, axis, value, name=None):
# [2, 1, 2] # [2, 1, 2]
# [2, 1, 2]] # [2, 1, 2]]
""" """
_index_add_params_check(x, index, axis, value)
return _C_ops.index_add_(x, index, value, axis) return _C_ops.index_add_(x, index, value, axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册