diff --git a/paddle/fluid/operators/diag_v2_op.cc b/paddle/fluid/operators/diag_v2_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..67dc2843345682b2dfe3d568e452461860575544 --- /dev/null +++ b/paddle/fluid/operators/diag_v2_op.cc @@ -0,0 +1,140 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/diag_v2_op.h" +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +class DiagV2Op : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "diag_v2"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "diag_v2"); + + auto x_dims = ctx->GetInputDim("X"); + auto offset = ctx->Attrs().Get("offset"); + + if (x_dims.size() == 1UL) { + int64_t size = x_dims[0] + std::abs(offset); + ctx->SetOutputDim("Out", {size, size}); + } else if (x_dims.size() == 2UL) { + int64_t size; + if (offset >= 0) { + size = std::min(x_dims[0], x_dims[1] - offset); + } else { + size = std::min(x_dims[0] + offset, x_dims[1]); + } + ctx->SetOutputDim("Out", {size}); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The input tensor X's dimensions of DiagV2Op should be either 1 or " + "2, but received %d.", + x_dims.size())); + } + } +}; + +class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor. Its shape is either 1-D or 2-D."); + AddOutput("Out", "The output tensor. A square matrix or a vector."); + AddAttr("offset", + "The diagonal offset. A positive value represents " + "superdiagonal, 0 represents the main diagonal, and a " + "negative value represents subdiagonal.") + .SetDefault(0); + AddAttr("padding_value", + "Use this value to fill the area outside the specified " + "diagonal band. Only takes effect when the input is a 1-D " + "Tensor. The default value is 0.") + .SetDefault(0.0f); + AddComment(R"DOC( + If ``x`` is a vector (1-D tensor), a 2-D square tensor whth the elements of ``x`` as the diagonal is returned. + + If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal elements of ``x`` is returned. + + The argument ``offset`` controls the diagonal offset: + + If ``offset`` = 0, it is the main diagonal. + + If ``offset`` > 0, it is superdiagonal. + + If ``offset`` < 0, it is subdiagonal. +)DOC"); + } +}; + +template +class DiagV2Kernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* x_data = X->data(); + auto x_dims = X->dims(); + int offset = context.Attr("offset"); + auto* out = context.Output("Out"); + T* out_data = out->mutable_data(context.GetPlace()); + auto out_dims = out->dims(); + + int64_t i; + if (x_dims.size() == 1) { + float padding_value = context.Attr("padding_value"); + math::SetConstant set_padding_value; + auto& dev_ctx = context.template device_context(); + set_padding_value(dev_ctx, out, static_cast(padding_value)); + + auto x_length = x_dims[0]; + const int& x_stride = ComputeStride(0, x_dims); + + auto out_stride_0 = ComputeStride(0, out_dims); + auto out_stride_1 = ComputeStride(1, out_dims); + out_data += + (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0); + + for (i = 0; i < x_length; i++) { + out_data[i * (out_stride_0 + out_stride_1)] = x_data[i * x_stride]; + } + } else { + auto out_length = out_dims[0]; + const int& x_stride_0 = ComputeStride(0, x_dims); + const int& x_stride_1 = ComputeStride(1, x_dims); + + auto out_stride_0 = ComputeStride(0, out_dims); + x_data += (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0); + for (i = 0; i < out_length; i++) { + out_data[i * out_stride_0] = x_data[i * (x_stride_0 + x_stride_1)]; + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + diag_v2, ops::DiagV2Kernel, + ops::DiagV2Kernel, + ops::DiagV2Kernel, + ops::DiagV2Kernel); diff --git a/paddle/fluid/operators/diag_v2_op.cu b/paddle/fluid/operators/diag_v2_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..4386cc6b8183c03b4d4a19aba7d1126eac2ab495 --- /dev/null +++ b/paddle/fluid/operators/diag_v2_op.cu @@ -0,0 +1,122 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/diag_v2_op.h" + +namespace paddle { +namespace operators { + +// Extract the diagonal of a matrix 'x' to a vector 'out'. +template +__global__ void ExtractDiagonalKernel(T* out, const T* x, std::ptrdiff_t start, + std::ptrdiff_t size, + const std::ptrdiff_t sumStride, + const std::ptrdiff_t outStride) { + for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + const std::ptrdiff_t xOffset = start + sumStride * idx; + out[outStride * idx] = x[xOffset]; + } +} + +// Paste a vector 'x' to the diagonal of a matrix 'out' +template +__global__ void PasteDiagonalKernel(T* out, const T* x, std::ptrdiff_t start, + std::ptrdiff_t x_length, + const std::ptrdiff_t sumStride, + const std::ptrdiff_t xStride) { + for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < x_length; idx += gridDim.x * blockDim.x) { + const std::ptrdiff_t outOffset = start + sumStride * idx; + out[outOffset] = x[xStride * idx]; + } +} + +template +class DiagV2CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* x_data = X->data(); + auto x_dims = X->dims(); + int offset = context.Attr("offset"); + auto* out = context.Output("Out"); + T* out_data = out->mutable_data(context.GetPlace()); + auto out_dims = out->dims(); + auto& dev_ctx = context.template device_context(); + + if (x_dims.size() == 1) { + float padding_value = context.Attr("padding_value"); + math::SetConstant set_padding_value; + set_padding_value(dev_ctx, out, static_cast(padding_value)); + + auto x_length = x_dims[0]; + auto size = (offset > 0) ? x_length + offset : x_length - offset; + const int& x_stride = ComputeStride(0, x_dims); + if (size > 0) { + const int block_num = std::min(static_cast(size), + dev_ctx.GetMaxPhysicalThreadCount()); + int size_ = static_cast(size); + int block_num_ = static_cast(block_num); + const int grid_num = + std::min(1024, (size_ + block_num_ - 1) / block_num_); + const auto& out_stride_0 = ComputeStride(0, out_dims); + const auto& out_stride_1 = ComputeStride(1, out_dims); + auto start = + (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0); + + PasteDiagonalKernel<<>>( + out_data, x_data, start, x_length, out_stride_0 + out_stride_1, + x_stride); + } + } else { + const int& x_stride_0 = ComputeStride(0, x_dims); + const int& x_stride_1 = ComputeStride(1, x_dims); + + int size; + if (offset > 0) { + size = std::min(x_dims[0], x_dims[1] - offset); + } else { + size = std::min(x_dims[0] + offset, x_dims[1]); + } + + if (size > 0) { + const int block_num = std::min(static_cast(size), + dev_ctx.GetMaxPhysicalThreadCount()); + int size_ = static_cast(size); + int block_num_ = static_cast(block_num); + const int grid_num = + std::min(1024, (size_ + block_num_ - 1) / block_num_); + auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0); + const auto& out_stride_0 = ComputeStride(0, out_dims); + + ExtractDiagonalKernel<<>>( + out_data, x_data, start, size, x_stride_0 + x_stride_1, + out_stride_0); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + diag_v2, ops::DiagV2CUDAKernel, + ops::DiagV2CUDAKernel, + ops::DiagV2CUDAKernel, + ops::DiagV2CUDAKernel); diff --git a/paddle/fluid/operators/diag_v2_op.h b/paddle/fluid/operators/diag_v2_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7850def06117ff4232afe4fca95a3e3e500e876d --- /dev/null +++ b/paddle/fluid/operators/diag_v2_op.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using DDim = framework::DDim; + +static inline int ComputeStride(int axis, DDim dims) { + int size = 1; + for (int i = axis + 1; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 5ce9d56cb0bd7244068684a379fd673a6107d50b..77a78eb4a14a0a5ad9be9cff71131ca473106ab8 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1559,6 +1559,7 @@ def zeros_like(x, out=None): return out +@deprecated(since="2.0.0", update_to="paddle.diag") def diag(diagonal): """ :alias_main: paddle.diag diff --git a/python/paddle/fluid/tests/unittests/test_diag.py b/python/paddle/fluid/tests/unittests/test_diag.py index b6566676d2533aad5272fe61dbedbc1d55ea213b..8bf40459902e09f19a5badce62084841a0a23619 100644 --- a/python/paddle/fluid/tests/unittests/test_diag.py +++ b/python/paddle/fluid/tests/unittests/test_diag.py @@ -17,11 +17,173 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid import Program, program_guard +class TestDiagV2Op(OpTest): + def setUp(self): + self.op_type = "diag_v2" + self.x = np.random.rand(10, 10) + self.offset = 0 + self.padding_value = 0.0 + self.out = np.diag(self.x, self.offset) + + self.init_config() + self.inputs = {'X': self.x} + self.attrs = { + 'offset': self.offset, + 'padding_value': self.padding_value + } + self.outputs = {'Out': self.out} + + def test_check_output(self): + self.check_output() + + def init_config(self): + pass + + +class TestDiagV2OpCase1(TestDiagV2Op): + def init_config(self): + self.offset = 1 + self.out = np.diag(self.x, self.offset) + + +class TestDiagV2OpCase2(TestDiagV2Op): + def init_config(self): + self.offset = -1 + self.out = np.diag(self.x, self.offset) + + +class TestDiagV2OpCase3(TestDiagV2Op): + def init_config(self): + self.x = np.random.randint(-10, 10, size=(10, 10)) + self.out = np.diag(self.x, self.offset) + + +class TestDiagV2OpCase4(TestDiagV2Op): + def init_config(self): + self.x = np.random.rand(100) + self.padding_value = 8 + n = self.x.size + self.out = self.padding_value * np.ones((n, n)) + np.diag( + self.x, self.offset) - np.diag(self.padding_value * np.ones(n)) + + +class TestDiagV2Error(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_diag_v2_type(): + x = [1, 2, 3] + output = paddle.diag(x) + + self.assertRaises(TypeError, test_diag_v2_type) + + +class TestDiagV2API(unittest.TestCase): + def setUp(self): + self.input_np = np.random.random(size=(10, 10)).astype(np.float32) + self.expected0 = np.diag(self.input_np) + self.expected1 = np.diag(self.input_np, k=1) + self.expected2 = np.diag(self.input_np, k=-1) + + self.input_np2 = np.random.rand(100) + self.offset = 0 + self.padding_value = 8 + n = self.input_np2.size + self.expected3 = self.padding_value * np.ones( + (n, n)) + np.diag(self.input_np2, self.offset) - np.diag( + self.padding_value * np.ones(n)) + + self.input_np3 = np.random.randint(-10, 10, size=(100)).astype(np.int64) + self.padding_value = 8.0 + n = self.input_np3.size + self.expected4 = self.padding_value * np.ones( + (n, n)) + np.diag(self.input_np3, self.offset) - np.diag( + self.padding_value * np.ones(n)) + + self.padding_value = -8 + self.expected5 = self.padding_value * np.ones( + (n, n)) + np.diag(self.input_np3, self.offset) - np.diag( + self.padding_value * np.ones(n)) + + def run_imperative(self): + x = paddle.to_tensor(self.input_np) + y = paddle.diag(x) + self.assertTrue(np.allclose(y.numpy(), self.expected0)) + + y = paddle.diag(x, offset=1) + self.assertTrue(np.allclose(y.numpy(), self.expected1)) + + y = paddle.diag(x, offset=-1) + self.assertTrue(np.allclose(y.numpy(), self.expected2)) + + x = paddle.to_tensor(self.input_np2) + y = paddle.diag(x, padding_value=8) + self.assertTrue(np.allclose(y.numpy(), self.expected3)) + + x = paddle.to_tensor(self.input_np3) + y = paddle.diag(x, padding_value=8.0) + self.assertTrue(np.allclose(y.numpy(), self.expected4)) + + y = paddle.diag(x, padding_value=-8) + self.assertTrue(np.allclose(y.numpy(), self.expected5)) + + def run_static(self, use_gpu=False): + x = paddle.data(name='input', shape=[10, 10], dtype='float32') + x2 = paddle.data(name='input2', shape=[100], dtype='float64') + x3 = paddle.data(name='input3', shape=[100], dtype='int64') + result0 = paddle.diag(x) + result1 = paddle.diag(x, offset=1) + result2 = paddle.diag(x, offset=-1) + result3 = paddle.diag(x, name='aaa') + result4 = paddle.diag(x2, padding_value=8) + result5 = paddle.diag(x3, padding_value=8.0) + result6 = paddle.diag(x3, padding_value=-8) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + res0, res1, res2, res4, res5, res6 = exe.run( + feed={ + "input": self.input_np, + "input2": self.input_np2, + 'input3': self.input_np3 + }, + fetch_list=[result0, result1, result2, result4, result5, result6]) + + self.assertTrue(np.allclose(res0, self.expected0)) + self.assertTrue(np.allclose(res1, self.expected1)) + self.assertTrue(np.allclose(res2, self.expected2)) + self.assertTrue('aaa' in result3.name) + self.assertTrue(np.allclose(res4, self.expected3)) + self.assertTrue(np.allclose(res5, self.expected4)) + self.assertTrue(np.allclose(res6, self.expected5)) + + def test_cpu(self): + paddle.disable_static(place=paddle.fluid.CPUPlace()) + self.run_imperative() + paddle.enable_static() + + with fluid.program_guard(fluid.Program()): + self.run_static() + + def test_gpu(self): + if not fluid.core.is_compiled_with_cuda(): + return + + paddle.disable_static(place=paddle.fluid.CUDAPlace(0)) + self.run_imperative() + paddle.enable_static() + + with fluid.program_guard(fluid.Program()): + self.run_static(use_gpu=True) + + class TestDiagOp(OpTest): def setUp(self): self.op_type = "diag" diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 416460d6b9d7ba267f327bfbfc5bc32331ee9d50..1911d8ccc25e01ee6419fd26126881304ab61f01 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -29,7 +29,6 @@ from paddle.common_ops_import import * # TODO: define functions to get create a tensor from ..fluid.layers import crop_tensor #DEFINE_ALIAS -from ..fluid.layers import diag #DEFINE_ALIAS from ..fluid.layers import fill_constant #DEFINE_ALIAS from ..fluid.layers import linspace #DEFINE_ALIAS import paddle @@ -903,3 +902,92 @@ def meshgrid(*args, **kwargs): type='meshgrid', inputs={'X': list(args)}, outputs={'Out': out}) return out + + +def diag(x, offset=0, padding_value=0, name=None): + """ + If ``x`` is a vector (1-D tensor), a 2-D square tensor whth the elements of ``x`` as the diagonal is returned. + + If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal elements of ``x`` is returned. + + The argument ``offset`` controls the diagonal offset: + + If ``offset`` = 0, it is the main diagonal. + + If ``offset`` > 0, it is superdiagonal. + + If ``offset`` < 0, it is subdiagonal. + + Args: + x (Tensor): The input tensor. Its shape is either 1-D or 2-D. Its data type should be float32, float64, int32, int64. + offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal. + padding_value (int|float, optional): Use this value to fill the area outside the specified diagonal band. Only takes effect when the input is a 1-D Tensor. The default value is 0. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, a square matrix or a vector. The output data type is the same as input data type. + + Examples: + .. code-block:: python + + import paddle + + paddle.disable_static() + x = paddle.to_tensor([1, 2, 3]) + y = paddle.diag(x) + print(y.numpy()) + # [[1 0 0] + # [0 2 0] + # [0 0 3]] + + y = paddle.diag(x, offset=1) + print(y.numpy()) + # [[0 1 0 0] + # [0 0 2 0] + # [0 0 0 3] + # [0 0 0 0]] + + y = paddle.diag(x, padding_value=6) + print(y.numpy()) + # [[1 6 6] + # [6 2 6] + # [6 6 3]] + + .. code-block:: python + + import paddle + + paddle.disable_static() + x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + y = paddle.diag(x) + print(y.numpy()) + # [1 5] + + y = paddle.diag(x, offset=1) + print(y.numpy()) + # [2 6] + + y = paddle.diag(x, offset=-1) + print(y.numpy()) + # [4] + """ + if in_dygraph_mode(): + return core.ops.diag_v2(x, "offset", offset, "padding_value", + padding_value) + + check_type(x, 'x', (Variable), 'diag_v2') + check_dtype(x.dtype, 'x', ['float32', 'float64', 'int32', 'int64'], + 'diag_v2') + helper = LayerHelper("diag_v2", **locals()) + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='diag_v2', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'offset': offset, + 'padding_value': padding_value}) + + out.stop_gradient = True + return out