未验证 提交 41a64351 编写于 作者: H huangxu96 提交者: GitHub

Take/Put_along_axis more input size support (#39072)

Support the cases that the indices shape size is larger than the arr shape size
上级 809a10b6
...@@ -82,7 +82,7 @@ class TestPutAlongAxisAPI(unittest.TestCase): ...@@ -82,7 +82,7 @@ class TestPutAlongAxisAPI(unittest.TestCase):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0)) self.place.append(paddle.CUDAPlace(0))
def test_api_static_case1(self): def test_api_static(self):
paddle.enable_static() paddle.enable_static()
def run(place): def run(place):
...@@ -110,7 +110,7 @@ class TestPutAlongAxisAPI(unittest.TestCase): ...@@ -110,7 +110,7 @@ class TestPutAlongAxisAPI(unittest.TestCase):
for place in self.place: for place in self.place:
run(place) run(place)
def test_api_dygraph_case1(self): def test_api_dygraph(self):
def run(place): def run(place):
paddle.disable_static(place) paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np) x_tensor = paddle.to_tensor(self.x_np)
...@@ -137,33 +137,7 @@ class TestPutAlongAxisAPI(unittest.TestCase): ...@@ -137,33 +137,7 @@ class TestPutAlongAxisAPI(unittest.TestCase):
for place in self.place: for place in self.place:
run(place) run(place)
def test_api_dygraph_case2(self): def test_inplace_dygraph(self):
def run(place):
paddle.disable_static(place)
self.shape = [2, 2]
self.index_shape = [2, 2]
self.index_np = np.array([[0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
x_tensor = paddle.to_tensor(self.x_np)
index_tensor = paddle.to_tensor(self.index_np)
value_tensor = paddle.to_tensor(self.value_np)
out = paddle.put_along_axis(x_tensor, index_tensor, value_tensor,
self.axis)
np.array(
np.put_along_axis(self.x_np, self.index_np, self.value_np,
self.axis))
out_ref = self.x_np
self.assertEqual(
np.allclose(
out.numpy(), out_ref, rtol=1e-03), True)
paddle.enable_static()
for place in self.place:
run(place)
def test_inplace_dygraph_case3(self):
def run(place): def run(place):
paddle.disable_static(place) paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np) x_tensor = paddle.to_tensor(self.x_np)
...@@ -186,6 +160,42 @@ class TestPutAlongAxisAPI(unittest.TestCase): ...@@ -186,6 +160,42 @@ class TestPutAlongAxisAPI(unittest.TestCase):
run(place) run(place)
class TestPutAlongAxisAPICase2(TestPutAlongAxisAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 2]
self.index_shape = [2, 2]
self.index_np = np.array([[0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 0
self.value_np = 99.0
self.value_shape = [1]
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
class TestPutAlongAxisAPICase3(TestPutAlongAxisAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 2]
self.index_shape = [4, 2]
self.index_np = np.array(
[[0, 0], [1, 0], [0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 0
self.value_np = 99.0
self.value_shape = [1]
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_inplace_dygraph(self):
pass
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -106,6 +106,20 @@ class TestTakeAlongAxisAPI(unittest.TestCase): ...@@ -106,6 +106,20 @@ class TestTakeAlongAxisAPI(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
class TestTakeAlongAxisAPICase1(TestTakeAlongAxisAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 2]
self.index_shape = [4, 2]
self.index_np = np.array(
[[0, 0], [1, 0], [0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 0
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -2751,6 +2751,31 @@ def moveaxis(x, source, destination, name=None): ...@@ -2751,6 +2751,31 @@ def moveaxis(x, source, destination, name=None):
return out return out
def non_negative_axis(arr, axis):
ndim = len(arr.shape)
if axis >= 0:
assert axis < ndim, "'axis' must be in the range of [-{0}, {0})".format(
ndim)
else:
assert axis >= -ndim, "'axis' must be in the range of [-{0}, {0})".format(
ndim)
axis += ndim
return axis
def infer_broadcast_shape(arr, indices, axis):
# This function is used in take/put_along_axis
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = list(indices.shape)[axis]
broadcast_shape = tuple(broadcast_shape_list)
for i in range(len(arr.shape)):
if arr.shape[i] < indices.shape[i]:
# if indices matrix has larger size than arr matrix, do not broadcast.
return None
return broadcast_shape
def take_along_axis(arr, indices, axis): def take_along_axis(arr, indices, axis):
""" """
Take values from the input array by given indices matrix along the designated axis. Take values from the input array by given indices matrix along the designated axis.
...@@ -2779,14 +2804,20 @@ def take_along_axis(arr, indices, axis): ...@@ -2779,14 +2804,20 @@ def take_along_axis(arr, indices, axis):
print(result) print(result)
# [[1, 2, 3]] # [[1, 2, 3]]
""" """
if (arr.shape == indices.shape): if (len(arr.shape) != len(indices.shape)):
broadcast_shape = arr.shape raise ValueError(
else: "`indices` and `arr` must have the same number of dimensions!")
broadcast_shape_list = list(arr.shape) axis = non_negative_axis(arr, axis)
broadcast_shape_list[axis] = 1 broadcast_shape = infer_broadcast_shape(arr, indices, axis)
broadcast_shape = tuple(broadcast_shape_list) if not broadcast_shape:
# if indices matrix have larger size than arr, arr should broadcast into indices shape.
broadcast_shape = indices.shape
if in_dygraph_mode(): if in_dygraph_mode():
indices = paddle.broadcast_to(indices, broadcast_shape) indices = paddle.broadcast_to(indices, broadcast_shape)
broadcast_shape_list = list(broadcast_shape)
broadcast_shape_list[axis] = list(arr.shape)[axis]
broadcast_shape = tuple(broadcast_shape_list)
arr = paddle.broadcast_to(arr, broadcast_shape)
return _C_ops.take_along_axis(arr, indices, 'Axis', axis) return _C_ops.take_along_axis(arr, indices, 'Axis', axis)
check_variable_and_dtype( check_variable_and_dtype(
arr, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], arr, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
...@@ -2794,6 +2825,10 @@ def take_along_axis(arr, indices, axis): ...@@ -2794,6 +2825,10 @@ def take_along_axis(arr, indices, axis):
check_variable_and_dtype(indices, 'index', ['int32', 'int64'], check_variable_and_dtype(indices, 'index', ['int32', 'int64'],
'take_along_axis') 'take_along_axis')
indices = paddle.broadcast_to(indices, broadcast_shape) indices = paddle.broadcast_to(indices, broadcast_shape)
broadcast_shape_list = list(broadcast_shape)
broadcast_shape_list[axis] = list(arr.shape)[axis]
broadcast_shape = tuple(broadcast_shape_list)
arr = paddle.broadcast_to(arr, broadcast_shape)
helper = LayerHelper('take_along_axis', **locals()) helper = LayerHelper('take_along_axis', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
result = helper.create_variable_for_type_inference(dtype) result = helper.create_variable_for_type_inference(dtype)
...@@ -2837,17 +2872,17 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): ...@@ -2837,17 +2872,17 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
# [60, 40, 50]] # [60, 40, 50]]
""" """
if (arr.shape == indices.shape): if (len(arr.shape) != len(indices.shape)):
broadcast_shape = arr.shape raise ValueError(
else: "`indices` and `arr` must have the same number of dimensions!")
broadcast_shape_list = list(arr.shape) axis = non_negative_axis(arr, axis)
broadcast_shape_list[axis] = 1 broadcast_shape = infer_broadcast_shape(arr, indices, axis)
broadcast_shape = tuple(broadcast_shape_list)
if in_dygraph_mode(): if in_dygraph_mode():
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.to_tensor(values) if not isinstance( values = paddle.to_tensor(values) if not isinstance(
values, paddle.Tensor) else values values, paddle.Tensor) else values
values = paddle.broadcast_to(values, broadcast_shape) if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, indices.shape)
return _C_ops.put_along_axis(arr, indices, values, "Axis", axis, return _C_ops.put_along_axis(arr, indices, values, "Axis", axis,
"Reduce", reduce) "Reduce", reduce)
...@@ -2856,8 +2891,9 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): ...@@ -2856,8 +2891,9 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
'put_along_axis') 'put_along_axis')
check_variable_and_dtype(indices, 'index', ['int32', 'int64'], check_variable_and_dtype(indices, 'index', ['int32', 'int64'],
'put_along_axis') 'put_along_axis')
if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape) indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, broadcast_shape) values = paddle.broadcast_to(values, indices.shape)
helper = LayerHelper('put_along_axis', **locals()) helper = LayerHelper('put_along_axis', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
result = helper.create_variable_for_type_inference(dtype) result = helper.create_variable_for_type_inference(dtype)
...@@ -2875,19 +2911,18 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): ...@@ -2875,19 +2911,18 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
@inplace_apis_in_dygraph_only @inplace_apis_in_dygraph_only
def put_along_axis_(arr, indices, values, axis, reduce='assign'): def put_along_axis_(arr, indices, values, axis, reduce='assign'):
r""" r"""
Inplace version of ``put_along_axis`` API, the output Tensor will be inplaced with input ``x``. Inplace version of ``put_along_axis`` API, the output Tensor will be inplaced with input ``arr``.
Please refer to :ref:`api_tensor_put_along_axis`. Please refer to :ref:`api_tensor_put_along_axis`.
""" """
if (arr.shape == indices.shape): if (len(arr.shape) != len(indices.shape)):
broadcast_shape = arr.shape raise ValueError(
else: "`indices` and `arr` must have the same number of dimensions!")
broadcast_shape_list = list(arr.shape) axis = non_negative_axis(arr, axis)
broadcast_shape_list[axis] = 1 broadcast_shape = infer_broadcast_shape(arr, indices, axis)
broadcast_shape = tuple(broadcast_shape_list)
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.to_tensor(values) if not isinstance( values = paddle.to_tensor(values) if not isinstance(
values, paddle.Tensor) else values values, paddle.Tensor) else values
values = paddle.broadcast_to(values, broadcast_shape) if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, indices.shape)
return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce", return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce",
reduce) reduce)
...@@ -437,17 +437,29 @@ def quantile(x, q, axis=None, keepdim=False): ...@@ -437,17 +437,29 @@ def quantile(x, q, axis=None, keepdim=False):
indices_upper = paddle.ceil(indices).astype(paddle.int32) indices_upper = paddle.ceil(indices).astype(paddle.int32)
outputs = [] outputs = []
def expand_dim(indices, sorted_tensor_shape, axis):
assert axis < len(list(sorted_tensor_shape))
expanded_shape = [1] * len(list(sorted_tensor_shape))
expanded_shape[axis] = len(indices)
expanded_shape = tuple(expanded_shape)
indices = indices.reshape(expanded_shape)
return indices
# TODO(chenjianye): replace the for-loop to directly take elements. # TODO(chenjianye): replace the for-loop to directly take elements.
for i in range(len(indices)): for i in range(len(indices)):
if (indices_upper[i] != indices_below[i]): if (indices_upper[i] != indices_below[i]):
tensor_below = paddle.take_along_axis(sorted_tensor, tensor_below = paddle.take_along_axis(
indices_below[i], axis) sorted_tensor,
tensor_upper = paddle.take_along_axis(sorted_tensor, expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
indices_upper[i], axis) tensor_upper = paddle.take_along_axis(
sorted_tensor,
expand_dim(indices_upper[i], sorted_tensor.shape, axis), axis)
weights = (indices[i] - indices_below[i]).astype(x.dtype) weights = (indices[i] - indices_below[i]).astype(x.dtype)
out = paddle.lerp(tensor_below, tensor_upper, weights) out = paddle.lerp(tensor_below, tensor_upper, weights)
else: else:
out = paddle.take_along_axis(sorted_tensor, indices_below[i], axis) out = paddle.take_along_axis(
sorted_tensor,
expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
if not keepdim: if not keepdim:
out = paddle.squeeze(out, axis=axis) out = paddle.squeeze(out, axis=axis)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册