未验证 提交 4e0c6d91 编写于 作者: L LutaoChu 提交者: GitHub

add paddle.tensor.linalg.diag API, diag_v2 OP and CUDA kernel

add paddle.tensor.linalg.diag API, diag_v2 OP and CUDA kernel.
上级 f8863e06
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/diag_v2_op.h"
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class DiagV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "diag_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "diag_v2");
auto x_dims = ctx->GetInputDim("X");
auto offset = ctx->Attrs().Get<int>("offset");
if (x_dims.size() == 1UL) {
int64_t size = x_dims[0] + std::abs(offset);
ctx->SetOutputDim("Out", {size, size});
} else if (x_dims.size() == 2UL) {
int64_t size;
if (offset >= 0) {
size = std::min(x_dims[0], x_dims[1] - offset);
} else {
size = std::min(x_dims[0] + offset, x_dims[1]);
}
ctx->SetOutputDim("Out", {size});
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor X's dimensions of DiagV2Op should be either 1 or "
"2, but received %d.",
x_dims.size()));
}
}
};
class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor. Its shape is either 1-D or 2-D.");
AddOutput("Out", "The output tensor. A square matrix or a vector.");
AddAttr<int>("offset",
"The diagonal offset. A positive value represents "
"superdiagonal, 0 represents the main diagonal, and a "
"negative value represents subdiagonal.")
.SetDefault(0);
AddAttr<float>("padding_value",
"Use this value to fill the area outside the specified "
"diagonal band. Only takes effect when the input is a 1-D "
"Tensor. The default value is 0.")
.SetDefault(0.0f);
AddComment(R"DOC(
If ``x`` is a vector (1-D tensor), a 2-D square tensor whth the elements of ``x`` as the diagonal is returned.
If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal elements of ``x`` is returned.
The argument ``offset`` controls the diagonal offset:
If ``offset`` = 0, it is the main diagonal.
If ``offset`` > 0, it is superdiagonal.
If ``offset`` < 0, it is subdiagonal.
)DOC");
}
};
template <typename DeviceContext, typename T>
class DiagV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* x_data = X->data<T>();
auto x_dims = X->dims();
int offset = context.Attr<int>("offset");
auto* out = context.Output<framework::Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_dims = out->dims();
int64_t i;
if (x_dims.size() == 1) {
float padding_value = context.Attr<float>("padding_value");
math::SetConstant<DeviceContext, T> set_padding_value;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
auto x_length = x_dims[0];
const int& x_stride = ComputeStride(0, x_dims);
auto out_stride_0 = ComputeStride(0, out_dims);
auto out_stride_1 = ComputeStride(1, out_dims);
out_data +=
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
for (i = 0; i < x_length; i++) {
out_data[i * (out_stride_0 + out_stride_1)] = x_data[i * x_stride];
}
} else {
auto out_length = out_dims[0];
const int& x_stride_0 = ComputeStride(0, x_dims);
const int& x_stride_1 = ComputeStride(1, x_dims);
auto out_stride_0 = ComputeStride(0, out_dims);
x_data += (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
for (i = 0; i < out_length; i++) {
out_data[i * out_stride_0] = x_data[i * (x_stride_0 + x_stride_1)];
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
diag_v2, ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_v2_op.h"
namespace paddle {
namespace operators {
// Extract the diagonal of a matrix 'x' to a vector 'out'.
template <typename T>
__global__ void ExtractDiagonalKernel(T* out, const T* x, std::ptrdiff_t start,
std::ptrdiff_t size,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t outStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t xOffset = start + sumStride * idx;
out[outStride * idx] = x[xOffset];
}
}
// Paste a vector 'x' to the diagonal of a matrix 'out'
template <typename T>
__global__ void PasteDiagonalKernel(T* out, const T* x, std::ptrdiff_t start,
std::ptrdiff_t x_length,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t xStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < x_length; idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t outOffset = start + sumStride * idx;
out[outOffset] = x[xStride * idx];
}
}
template <typename DeviceContext, typename T>
class DiagV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* x_data = X->data<T>();
auto x_dims = X->dims();
int offset = context.Attr<int>("offset");
auto* out = context.Output<framework::Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_dims = out->dims();
auto& dev_ctx = context.template device_context<DeviceContext>();
if (x_dims.size() == 1) {
float padding_value = context.Attr<float>("padding_value");
math::SetConstant<DeviceContext, T> set_padding_value;
set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
auto x_length = x_dims[0];
auto size = (offset > 0) ? x_length + offset : x_length - offset;
const int& x_stride = ComputeStride(0, x_dims);
if (size > 0) {
const int block_num = std::min(static_cast<int>(size),
dev_ctx.GetMaxPhysicalThreadCount());
int size_ = static_cast<int>(size);
int block_num_ = static_cast<int>(block_num);
const int grid_num =
std::min(1024, (size_ + block_num_ - 1) / block_num_);
const auto& out_stride_0 = ComputeStride(0, out_dims);
const auto& out_stride_1 = ComputeStride(1, out_dims);
auto start =
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
PasteDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>(
out_data, x_data, start, x_length, out_stride_0 + out_stride_1,
x_stride);
}
} else {
const int& x_stride_0 = ComputeStride(0, x_dims);
const int& x_stride_1 = ComputeStride(1, x_dims);
int size;
if (offset > 0) {
size = std::min(x_dims[0], x_dims[1] - offset);
} else {
size = std::min(x_dims[0] + offset, x_dims[1]);
}
if (size > 0) {
const int block_num = std::min(static_cast<int>(size),
dev_ctx.GetMaxPhysicalThreadCount());
int size_ = static_cast<int>(size);
int block_num_ = static_cast<int>(block_num);
const int grid_num =
std::min(1024, (size_ + block_num_ - 1) / block_num_);
auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
const auto& out_stride_0 = ComputeStride(0, out_dims);
ExtractDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>(
out_data, x_data, start, size, x_stride_0 + x_stride_1,
out_stride_0);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
diag_v2, ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using DDim = framework::DDim;
static inline int ComputeStride(int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
} // namespace operators
} // namespace paddle
......@@ -1559,6 +1559,7 @@ def zeros_like(x, out=None):
return out
@deprecated(since="2.0.0", update_to="paddle.diag")
def diag(diagonal):
"""
:alias_main: paddle.diag
......
......@@ -17,11 +17,173 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import Program, program_guard
class TestDiagV2Op(OpTest):
def setUp(self):
self.op_type = "diag_v2"
self.x = np.random.rand(10, 10)
self.offset = 0
self.padding_value = 0.0
self.out = np.diag(self.x, self.offset)
self.init_config()
self.inputs = {'X': self.x}
self.attrs = {
'offset': self.offset,
'padding_value': self.padding_value
}
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
def init_config(self):
pass
class TestDiagV2OpCase1(TestDiagV2Op):
def init_config(self):
self.offset = 1
self.out = np.diag(self.x, self.offset)
class TestDiagV2OpCase2(TestDiagV2Op):
def init_config(self):
self.offset = -1
self.out = np.diag(self.x, self.offset)
class TestDiagV2OpCase3(TestDiagV2Op):
def init_config(self):
self.x = np.random.randint(-10, 10, size=(10, 10))
self.out = np.diag(self.x, self.offset)
class TestDiagV2OpCase4(TestDiagV2Op):
def init_config(self):
self.x = np.random.rand(100)
self.padding_value = 8
n = self.x.size
self.out = self.padding_value * np.ones((n, n)) + np.diag(
self.x, self.offset) - np.diag(self.padding_value * np.ones(n))
class TestDiagV2Error(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_diag_v2_type():
x = [1, 2, 3]
output = paddle.diag(x)
self.assertRaises(TypeError, test_diag_v2_type)
class TestDiagV2API(unittest.TestCase):
def setUp(self):
self.input_np = np.random.random(size=(10, 10)).astype(np.float32)
self.expected0 = np.diag(self.input_np)
self.expected1 = np.diag(self.input_np, k=1)
self.expected2 = np.diag(self.input_np, k=-1)
self.input_np2 = np.random.rand(100)
self.offset = 0
self.padding_value = 8
n = self.input_np2.size
self.expected3 = self.padding_value * np.ones(
(n, n)) + np.diag(self.input_np2, self.offset) - np.diag(
self.padding_value * np.ones(n))
self.input_np3 = np.random.randint(-10, 10, size=(100)).astype(np.int64)
self.padding_value = 8.0
n = self.input_np3.size
self.expected4 = self.padding_value * np.ones(
(n, n)) + np.diag(self.input_np3, self.offset) - np.diag(
self.padding_value * np.ones(n))
self.padding_value = -8
self.expected5 = self.padding_value * np.ones(
(n, n)) + np.diag(self.input_np3, self.offset) - np.diag(
self.padding_value * np.ones(n))
def run_imperative(self):
x = paddle.to_tensor(self.input_np)
y = paddle.diag(x)
self.assertTrue(np.allclose(y.numpy(), self.expected0))
y = paddle.diag(x, offset=1)
self.assertTrue(np.allclose(y.numpy(), self.expected1))
y = paddle.diag(x, offset=-1)
self.assertTrue(np.allclose(y.numpy(), self.expected2))
x = paddle.to_tensor(self.input_np2)
y = paddle.diag(x, padding_value=8)
self.assertTrue(np.allclose(y.numpy(), self.expected3))
x = paddle.to_tensor(self.input_np3)
y = paddle.diag(x, padding_value=8.0)
self.assertTrue(np.allclose(y.numpy(), self.expected4))
y = paddle.diag(x, padding_value=-8)
self.assertTrue(np.allclose(y.numpy(), self.expected5))
def run_static(self, use_gpu=False):
x = paddle.data(name='input', shape=[10, 10], dtype='float32')
x2 = paddle.data(name='input2', shape=[100], dtype='float64')
x3 = paddle.data(name='input3', shape=[100], dtype='int64')
result0 = paddle.diag(x)
result1 = paddle.diag(x, offset=1)
result2 = paddle.diag(x, offset=-1)
result3 = paddle.diag(x, name='aaa')
result4 = paddle.diag(x2, padding_value=8)
result5 = paddle.diag(x3, padding_value=8.0)
result6 = paddle.diag(x3, padding_value=-8)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
res0, res1, res2, res4, res5, res6 = exe.run(
feed={
"input": self.input_np,
"input2": self.input_np2,
'input3': self.input_np3
},
fetch_list=[result0, result1, result2, result4, result5, result6])
self.assertTrue(np.allclose(res0, self.expected0))
self.assertTrue(np.allclose(res1, self.expected1))
self.assertTrue(np.allclose(res2, self.expected2))
self.assertTrue('aaa' in result3.name)
self.assertTrue(np.allclose(res4, self.expected3))
self.assertTrue(np.allclose(res5, self.expected4))
self.assertTrue(np.allclose(res6, self.expected5))
def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace())
self.run_imperative()
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
self.run_static()
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
paddle.disable_static(place=paddle.fluid.CUDAPlace(0))
self.run_imperative()
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
self.run_static(use_gpu=True)
class TestDiagOp(OpTest):
def setUp(self):
self.op_type = "diag"
......
......@@ -29,7 +29,6 @@ from paddle.common_ops_import import *
# TODO: define functions to get create a tensor
from ..fluid.layers import crop_tensor #DEFINE_ALIAS
from ..fluid.layers import diag #DEFINE_ALIAS
from ..fluid.layers import fill_constant #DEFINE_ALIAS
from ..fluid.layers import linspace #DEFINE_ALIAS
import paddle
......@@ -903,3 +902,92 @@ def meshgrid(*args, **kwargs):
type='meshgrid', inputs={'X': list(args)}, outputs={'Out': out})
return out
def diag(x, offset=0, padding_value=0, name=None):
"""
If ``x`` is a vector (1-D tensor), a 2-D square tensor whth the elements of ``x`` as the diagonal is returned.
If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal elements of ``x`` is returned.
The argument ``offset`` controls the diagonal offset:
If ``offset`` = 0, it is the main diagonal.
If ``offset`` > 0, it is superdiagonal.
If ``offset`` < 0, it is subdiagonal.
Args:
x (Tensor): The input tensor. Its shape is either 1-D or 2-D. Its data type should be float32, float64, int32, int64.
offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal.
padding_value (int|float, optional): Use this value to fill the area outside the specified diagonal band. Only takes effect when the input is a 1-D Tensor. The default value is 0.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, a square matrix or a vector. The output data type is the same as input data type.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
x = paddle.to_tensor([1, 2, 3])
y = paddle.diag(x)
print(y.numpy())
# [[1 0 0]
# [0 2 0]
# [0 0 3]]
y = paddle.diag(x, offset=1)
print(y.numpy())
# [[0 1 0 0]
# [0 0 2 0]
# [0 0 0 3]
# [0 0 0 0]]
y = paddle.diag(x, padding_value=6)
print(y.numpy())
# [[1 6 6]
# [6 2 6]
# [6 6 3]]
.. code-block:: python
import paddle
paddle.disable_static()
x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
y = paddle.diag(x)
print(y.numpy())
# [1 5]
y = paddle.diag(x, offset=1)
print(y.numpy())
# [2 6]
y = paddle.diag(x, offset=-1)
print(y.numpy())
# [4]
"""
if in_dygraph_mode():
return core.ops.diag_v2(x, "offset", offset, "padding_value",
padding_value)
check_type(x, 'x', (Variable), 'diag_v2')
check_dtype(x.dtype, 'x', ['float32', 'float64', 'int32', 'int64'],
'diag_v2')
helper = LayerHelper("diag_v2", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='diag_v2',
inputs={'X': x},
outputs={'Out': out},
attrs={'offset': offset,
'padding_value': padding_value})
out.stop_gradient = True
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册