未验证 提交 84e5d099 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

add new API:paddle.moveaxis/Tensor.moveaxis (#37833)

* add new API:paddle.movedim/moveaxis

* add new API:paddle.movedim/moveaxis

* add new API:add new API:paddle.movedim/moveaxis

* fix comment

* fix comment
上级 038ca68d
...@@ -158,7 +158,7 @@ from .tensor.manipulation import tolist # noqa: F401 ...@@ -158,7 +158,7 @@ from .tensor.manipulation import tolist # noqa: F401
from .tensor.manipulation import tensordot # noqa: F401 from .tensor.manipulation import tensordot # noqa: F401
from .tensor.manipulation import as_complex # noqa: F401 from .tensor.manipulation import as_complex # noqa: F401
from .tensor.manipulation import as_real # noqa: F401 from .tensor.manipulation import as_real # noqa: F401
from .tensor.manipulation import moveaxis # noqa: F401
from .tensor.math import abs # noqa: F401 from .tensor.math import abs # noqa: F401
from .tensor.math import acos # noqa: F401 from .tensor.math import acos # noqa: F401
from .tensor.math import asin # noqa: F401 from .tensor.math import asin # noqa: F401
...@@ -568,4 +568,5 @@ __all__ = [ # noqa ...@@ -568,4 +568,5 @@ __all__ = [ # noqa
'as_real', 'as_real',
'diff', 'diff',
'angle', 'angle',
'moveaxis',
] ]
...@@ -348,5 +348,77 @@ class TestTAPI(unittest.TestCase): ...@@ -348,5 +348,77 @@ class TestTAPI(unittest.TestCase):
self.assertRaises(ValueError, test_x_dimension_check) self.assertRaises(ValueError, test_x_dimension_check)
class TestMoveAxis(unittest.TestCase):
def test_moveaxis1(self):
x_np = np.random.randn(2, 3, 4, 5, 7)
expected = np.moveaxis(x_np, [0, 4, 3, 2], [1, 3, 2, 0])
paddle.enable_static()
with paddle.static.program_guard(fluid.Program()):
x = paddle.static.data("x", shape=[2, 3, 4, 5, 7], dtype='float64')
out = paddle.moveaxis(x, [0, 4, 3, 2], [1, 3, 2, 0])
exe = paddle.static.Executor()
out_np = exe.run(feed={"x": x_np}, fetch_list=[out])[0]
self.assertEqual(np.array_equal(out_np, expected), True)
paddle.disable_static()
x = paddle.to_tensor(x_np)
out = paddle.moveaxis(x, [0, 4, 3, 2], [1, 3, 2, 0])
self.assertEqual(out.shape, [4, 2, 5, 7, 3])
self.assertEqual(np.array_equal(out.numpy(), expected), True)
paddle.enable_static()
def test_moveaxis2(self):
x_np = np.random.randn(2, 3, 5)
expected = np.moveaxis(x_np, -2, -1)
paddle.enable_static()
with paddle.static.program_guard(fluid.Program()):
x = paddle.static.data("x", shape=[2, 3, 5], dtype='float64')
out = x.moveaxis(-2, -1)
exe = paddle.static.Executor()
out_np = exe.run(feed={"x": x_np}, fetch_list=[out])[0]
self.assertEqual(np.array_equal(out_np, expected), True)
paddle.disable_static()
x = paddle.to_tensor(x_np)
out = x.moveaxis(-2, -1)
self.assertEqual(out.shape, [2, 5, 3])
self.assertEqual(np.array_equal(out.numpy(), expected), True)
paddle.enable_static()
def test_error(self):
x = paddle.randn([2, 3, 4, 5])
# src must have the same number with dst
with self.assertRaises(AssertionError):
paddle.moveaxis(x, [1, 0], [2])
# each element of src must be unique
with self.assertRaises(ValueError):
paddle.moveaxis(x, [1, 1], [0, 2])
# each element of dst must be unique
with self.assertRaises(ValueError):
paddle.moveaxis(x, [0, 1], [2, 2])
# each element of src must be integer
with self.assertRaises(AssertionError):
paddle.moveaxis(x, [0.5], [1])
# each element of dst must be integer
with self.assertRaises(AssertionError):
paddle.moveaxis(x, [0], [1.5])
# each element of src must be in the range of [-4, 3)
with self.assertRaises(AssertionError):
paddle.moveaxis(x, [-10, 1], [2, 3])
# each element of dst must be in the range of [-4, 3)
with self.assertRaises(AssertionError):
paddle.moveaxis(x, [2, 1], [10, 3])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -113,7 +113,7 @@ from .manipulation import chunk # noqa: F401 ...@@ -113,7 +113,7 @@ from .manipulation import chunk # noqa: F401
from .manipulation import tensordot # noqa: F401 from .manipulation import tensordot # noqa: F401
from .manipulation import as_complex # noqa: F401 from .manipulation import as_complex # noqa: F401
from .manipulation import as_real # noqa: F401 from .manipulation import as_real # noqa: F401
from .manipulation import moveaxis # noqa: F401
from .math import abs # noqa: F401 from .math import abs # noqa: F401
from .math import acos # noqa: F401 from .math import acos # noqa: F401
from .math import asin # noqa: F401 from .math import asin # noqa: F401
...@@ -426,6 +426,7 @@ tensor_method_func = [ #noqa ...@@ -426,6 +426,7 @@ tensor_method_func = [ #noqa
'lerp', 'lerp',
'lerp_', 'lerp_',
'angle', 'angle',
'moveaxis'
] ]
#this list used in math_op_patch.py for magic_method bind #this list used in math_op_patch.py for magic_method bind
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
from collections import Counter
from ..fluid.layers import core from ..fluid.layers import core
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
...@@ -2579,3 +2580,103 @@ def as_real(x, name=None): ...@@ -2579,3 +2580,103 @@ def as_real(x, name=None):
outputs = {"Out": out} outputs = {"Out": out}
helper.append_op(type=op_type, inputs=inputs, outputs=outputs) helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
return out return out
def moveaxis(x, source, destination, name=None):
"""
Move the axis of tensor from ``source`` position to ``destination`` position.
Other axis that have not been moved remain their original order.
Args:
x (Tensor): The input Tensor. It is a N-D Tensor of data types bool, int32, int64, float32, float64, complex64, complex128.
source(int|tuple|list): ``source`` position of axis that will be moved. Each element must be unique and integer.
destination(int|tuple|list(int)): ``destination`` position of axis that has been moved. Each element must be unique and integer.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: A new tensor whose axis have been moved.
Examples:
.. code-block:: python
import paddle
x = paddle.ones([3, 2, 4])
paddle.moveaxis(x, [0, 1], [1, 2]).shape
# [4, 3, 2]
x = paddle.ones([2, 3])
paddle.moveaxis(x, 0, 1) # equivalent to paddle.t(x)
# [3, 2]
"""
src = [source] if isinstance(source, int) else source
dst = [destination] if isinstance(destination, int) else destination
assert len(src) == len(
dst), "'source' must have the same number with 'destination'"
count = Counter(src).most_common(1)
if count[0][1] > 1:
raise ValueError("Each elemment of 'source' must be unique!")
count = Counter(dst).most_common(1)
if count[0][1] > 1:
raise ValueError("Each elemment of 'destination' must be unique!")
ndim = len(x.shape)
# perm is the new order after move axis
perm = list(range(ndim))
src_dims = list(range(ndim))
dst_dims = list(range(ndim))
for i, axis in enumerate(zip(src, dst)):
assert isinstance(axis[0],
int), "Each elemment of 'source' must be integer."
if axis[0] < 0:
assert axis[
0] >= -ndim, "'source' must be in the range of [-{0}, {0})".format(
ndim)
src[i] += ndim
else:
assert axis[
0] < ndim, "'source' must be in the range of [-{0}, {0})".format(
ndim)
assert isinstance(axis[1],
int), "Each elemment of 'source' must be integer."
if axis[1] < 0:
assert axis[
1] >= -ndim, "'source' must be in the range of [-{0}, {0})".format(
ndim)
dst[i] += ndim
else:
assert axis[
1] < ndim, "'source' must be in the range of [-{0}, {0})".format(
ndim)
perm[dst[i]] = src[i]
src_dims.remove(src[i])
dst_dims.remove(dst[i])
for i in range(len(src_dims)):
perm[dst_dims[i]] = src_dims[i]
if in_dygraph_mode():
out, _ = _C_ops.transpose2(x, 'axis', perm)
return out
check_variable_and_dtype(
x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'moveaxis')
helper = LayerHelper('moveaxis', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
x_shape = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='transpose2',
inputs={'X': [x]},
outputs={'Out': [out],
'XShape': [x_shape]},
attrs={'axis': perm})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册