未验证 提交 32ef95d7 编写于 作者: L Li Min 提交者: GitHub

Add diagflat op, test=develop (#33334)

上级 555c3463
......@@ -72,6 +72,7 @@ from .tensor.attribute import real # noqa: F401
from .tensor.attribute import imag # noqa: F401
from .tensor.creation import to_tensor # noqa: F401
from .tensor.creation import diag # noqa: F401
from .tensor.creation import diagflat # noqa: F401
from .tensor.creation import eye # noqa: F401
from .tensor.creation import linspace # noqa: F401
from .tensor.creation import ones # noqa: F401
......@@ -301,6 +302,7 @@ __all__ = [ #noqa
'add',
'subtract',
'diag',
'diagflat',
'isnan',
'scatter_nd_add',
'unstack',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.static import Program, program_guard
class TestDiagFlatError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
def test_diagflat_type():
x = [1, 2, 3]
output = paddle.diagflat(x)
self.assertRaises(TypeError, test_diagflat_type)
x = paddle.static.data('data', [3, 3])
self.assertRaises(TypeError, paddle.diagflat, x, offset=2.5)
class TestDiagFlatAPI(unittest.TestCase):
def setUp(self):
self.input_np = np.random.random(size=(10, 10)).astype(np.float64)
self.expected0 = np.diagflat(self.input_np)
self.expected1 = np.diagflat(self.input_np, k=1)
self.expected2 = np.diagflat(self.input_np, k=-1)
self.input_np2 = np.random.random(size=(20)).astype(np.float64)
self.expected3 = np.diagflat(self.input_np2)
self.expected4 = np.diagflat(self.input_np2, k=1)
self.expected5 = np.diagflat(self.input_np2, k=-1)
def run_imperative(self):
x = paddle.to_tensor(self.input_np)
y = paddle.diagflat(x)
self.assertTrue(np.allclose(y.numpy(), self.expected0))
y = paddle.diagflat(x, offset=1)
self.assertTrue(np.allclose(y.numpy(), self.expected1))
y = paddle.diagflat(x, offset=-1)
self.assertTrue(np.allclose(y.numpy(), self.expected2))
x = paddle.to_tensor(self.input_np2)
y = paddle.diagflat(x)
self.assertTrue(np.allclose(y.numpy(), self.expected3))
y = paddle.diagflat(x, offset=1)
self.assertTrue(np.allclose(y.numpy(), self.expected4))
y = paddle.diagflat(x, offset=-1)
self.assertTrue(np.allclose(y.numpy(), self.expected5))
def run_static(self, use_gpu=False):
x = paddle.static.data(name='input', shape=[10, 10], dtype='float64')
x2 = paddle.static.data(name='input2', shape=[20], dtype='float64')
result0 = paddle.diagflat(x)
result3 = paddle.diagflat(x2)
place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
res0, res3 = exe.run(
feed={"input": self.input_np,
'input2': self.input_np2},
fetch_list=[result0, result3])
self.assertTrue(np.allclose(res0, self.expected0))
self.assertTrue(np.allclose(res3, self.expected3))
def test_cpu(self):
paddle.disable_static(place=paddle.CPUPlace())
self.run_imperative()
paddle.enable_static()
with paddle.static.program_guard(Program()):
self.run_static()
def test_gpu(self):
if not paddle.is_compiled_with_cuda():
return
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative()
paddle.enable_static()
with paddle.static.program_guard(Program()):
self.run_static(use_gpu=True)
if __name__ == "__main__":
unittest.main()
......@@ -18,6 +18,7 @@ from .attribute import real # noqa: F401
from .attribute import imag # noqa: F401
from .creation import to_tensor # noqa: F401
from .creation import diag # noqa: F401
from .creation import diagflat # noqa: F401
from .creation import eye # noqa: F401
from .creation import linspace # noqa: F401
from .creation import ones # noqa: F401
......
......@@ -772,6 +772,131 @@ def meshgrid(*args, **kwargs):
return out
def diagflat(x, offset=0, name=None):
"""
If ``x`` is a vector (1-D tensor), a 2-D square tensor whth the elements of ``x`` as the diagonal is returned.
If ``x`` is a tensor (more than 1-D), a 2-D square tensor with the elements of flattened ``x`` as the diagonal is returned.
The argument ``offset`` controls the diagonal offset.
If ``offset`` = 0, it is the main diagonal.
If ``offset`` > 0, it is superdiagonal.
If ``offset`` < 0, it is subdiagonal.
Args:
x (Tensor): The input tensor. It can be any shape. Its data type should be float32, float64, int32, int64.
offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal. Default: 0 (main diagonal).
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, a square matrix. The output data type is the same as input data type.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1, 2, 3])
y = paddle.diagflat(x)
print(y.numpy())
# [[1 0 0]
# [0 2 0]
# [0 0 3]]
y = paddle.diagflat(x, offset=1)
print(y.numpy())
# [[0 1 0 0]
# [0 0 2 0]
# [0 0 0 3]
# [0 0 0 0]]
y = paddle.diagflat(x, offset=-1)
print(y.numpy())
# [[0 0 0 0]
# [1 0 0 0]
# [0 2 0 0]
# [0 0 3 0]]
.. code-block:: python
import paddle
x = paddle.to_tensor([[1, 2], [3, 4]])
y = paddle.diagflat(x)
print(y.numpy())
# [[1 0 0 0]
# [0 2 0 0]
# [0 0 3 0]
# [0 0 0 4]]
y = paddle.diagflat(x, offset=1)
print(y.numpy())
# [[0 1 0 0 0]
# [0 0 2 0 0]
# [0 0 0 3 0]
# [0 0 0 0 4]
# [0 0 0 0 0]]
y = paddle.diagflat(x, offset=-1)
print(y.numpy())
# [[0 0 0 0 0]
# [1 0 0 0 0]
# [0 2 0 0 0]
# [0 0 3 0 0]
# [0 0 0 4 0]]
"""
padding_value = 0
if in_dygraph_mode():
if len(x.shape) == 1:
return core.ops.diag_v2(x, "offset", offset, "padding_value",
padding_value)
else:
y, _ = core.ops.flatten_contiguous_range(x, "start_axis", 0,
"stop_axis", -1)
return core.ops.diag_v2(y, "offset", offset, "padding_value",
padding_value)
check_type(x, 'x', (Variable), 'diagflat')
check_dtype(x.dtype, 'x', ['float32', 'float64', 'int32', 'int64'],
'diagflat')
check_type(offset, 'offset', (int), 'diagflat')
helper = LayerHelper("diagflat", **locals())
out1 = helper.create_variable_for_type_inference(dtype=x.dtype)
out1_shape = helper.create_variable_for_type_inference(x.dtype)
out2 = helper.create_variable_for_type_inference(dtype=x.dtype)
if len(x.shape) == 1:
helper.append_op(
type='diag_v2',
inputs={'X': x},
outputs={'Out': out2},
attrs={'offset': offset,
'padding_value': padding_value})
else:
helper.append_op(
type='flatten_contiguous_range',
inputs={'X': x},
outputs={'Out': out1,
'XShape': out1_shape},
attrs={'start_axis': 0,
'stop_axis': -1})
out1.stop_gradient = True
helper.append_op(
type='diag_v2',
inputs={'X': out1},
outputs={'Out': out2},
attrs={'offset': offset,
'padding_value': padding_value})
out2.stop_gradient = True
return out2
def diag(x, offset=0, padding_value=0, name=None):
"""
If ``x`` is a vector (1-D tensor), a 2-D square tensor whth the elements of ``x`` as the diagonal is returned.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册