Unverified commit ad106290, authored by zhangbo9674 and committed by GitHub

[API/OP]Add a new API paddle.diagonal (#33586)

* new api diagonal, test=develop

* add new api diagonal, test=develop

* new api diagonal, test=develop

* add new api paddle.diagonal, test=develop

* use framework::stride to replace ComputeDimStride

* replace cudaMalloc/cudaMemcpy with TensorFromVector in cudaKernel and cudaGradKernel

* improve function: when attr(offset) exceeds the size along attr(axis1) or attr(axis2), set the diagonal dim to 0

* fix RP-Mac-CI bug: replace framework::stride() with ComputeDimStride.

* improve code-block

* improve code of python API diagonal

* api supports dtype of float16 and bool

* api supports dtype of float16 and bool

* modify unittest code

* modify unittest code

* improve dtype description

* improve code-block
Parent commit 1cfbcb14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/diagonal_op.h"
namespace paddle {
namespace operators {
class DiagonalOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "diagonal");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "diagonal");
int offset_ = ctx->Attrs().Get<int>("offset");
int axis1 = ctx->Attrs().Get<int>("axis1");
int axis2 = ctx->Attrs().Get<int>("axis2");
auto x_dims = ctx->GetInputDim("Input");
int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1;
int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2;
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
platform::errors::OutOfRange("Input's dim is out of range (expected at "
"least 2 dimensions, but got %ld).",
x_dims.size()));
PADDLE_ENFORCE_LT(
axis1_, x_dims.size(),
platform::errors::OutOfRange(
"Attr(axis1) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()), (x_dims.size() - 1), axis1));
PADDLE_ENFORCE_LT(
axis2_, x_dims.size(),
platform::errors::OutOfRange(
"Attr(axis2) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()), (x_dims.size() - 1), axis2));
PADDLE_ENFORCE_NE(axis1_, axis2_,
platform::errors::InvalidArgument(
"The dimensions should not be identical "
"%d vs %d.",
axis1, axis2));
auto out_dims = vectorize(x_dims);
// get the dim sizes along axis1_ and axis2_ from out_dims.
auto axis1_size = out_dims[axis1_];
auto axis2_size = out_dims[axis2_];
// remove the two dims specified by attr axis1 and axis2 from out_dims.
/* example:
out_dim = [2, 3, 4];
axis1 = 0;
axis2 = 1;
according to the attr of axis1 and axis2, we get:
out_dim = [4].
*/
out_dims.erase(out_dims.begin() + std::max(axis1_, axis2_));
out_dims.erase(out_dims.begin() + std::min(axis1_, axis2_));
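/* Worked example of the offset handling below (illustrative only):
   x_dims = [2, 3, 4], axis1 = 0, axis2 = 1 -> axis1_size = 2, axis2_size = 3.
   offset = 2: axis2_size - offset = 1 > 0, so out_dims = [4, min(2, 1)] = [4, 1];
   offset = 3: axis2_size - offset = 0, so the diagonal dim is 0 and out_dims = [4, 0]. */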
if (offset_ == 0) {
out_dims.push_back(std::min(axis1_size, axis2_size));
} else if (offset_ > 0) {
if ((axis2_size - offset_) > 0) {
out_dims.push_back(std::min(axis1_size, axis2_size - offset_));
} else {
out_dims.push_back(0);
}
} else {
if ((axis1_size + offset_) > 0) {
out_dims.push_back(std::min(axis1_size + offset_, axis2_size));
} else {
out_dims.push_back(0);
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
}
};
class DiagonalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor) The input tensor, from which the diagonals are taken.");
AddOutput(
"Out",
"(Tensor) The partial view of input with the its diagonal elements.");
AddAttr<int>(
"offset",
R"DOC((int, default 0), offset of the diagonal from the main diagonal. Can be both positive and negative. Default: 0.
)DOC")
.SetDefault(0);
AddAttr<int>(
"axis1",
R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken.
Can be either positive or negative. Default: 0.
)DOC")
.SetDefault(0);
AddAttr<int>(
"axis2",
R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken.
Can be either positive or negative. Default: 1.
)DOC")
.SetDefault(1);
AddComment(R"DOC(
Diagonal Operator.
Return a partial view of the input tensor containing its diagonal elements.
The behavior of this operator is similar to how `numpy.diagonal` works.
)DOC");
}
};
class DiagonalGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "DiagonalGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
framework::GradVarName("Input"), "DiagonalGrad");
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
template <typename T>
class DiagonalGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("diagonal_grad");
grad_op->SetInput("Input", this->Input("Input"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("Input"),
this->InputGrad("Input"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(DiagonalGradNoNeedBufferVarsInferer,
"Input");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(diagonal, ops::DiagonalOp, ops::DiagonalOpMaker,
ops::DiagonalGradOpMaker<paddle::framework::OpDesc>,
ops::DiagonalGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(diagonal_grad, ops::DiagonalGradOp,
ops::DiagonalGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(diagonal, ops::DiagonalKernel<int>,
ops::DiagonalKernel<int64_t>, ops::DiagonalKernel<float>,
ops::DiagonalKernel<double>, ops::DiagonalKernel<bool>);
REGISTER_OP_CPU_KERNEL(diagonal_grad, ops::DiagonalGradKernel<int>,
ops::DiagonalGradKernel<int64_t>,
ops::DiagonalGradKernel<float>,
ops::DiagonalGradKernel<double>);
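For reference, the output-shape rule implemented in DiagonalOp::InferShape above can be reproduced and cross-checked against numpy.diagonal with the short Python sketch below (illustrative only, not part of this commit; the helper name diagonal_out_shape is made up here):

import numpy as np

def diagonal_out_shape(x_shape, offset=0, axis1=0, axis2=1):
    rank = len(x_shape)
    a1, a2 = axis1 % rank, axis2 % rank  # wrap negative axes, as InferShape does
    sz1, sz2 = x_shape[a1], x_shape[a2]
    out = [d for i, d in enumerate(x_shape) if i not in (a1, a2)]
    if offset >= 0:
        out.append(max(0, min(sz1, sz2 - offset)))  # 0 when offset is too large
    else:
        out.append(max(0, min(sz1 + offset, sz2)))
    return out

x = np.zeros((2, 3, 4))
for off in (-5, -1, 0, 2, 10):
    assert diagonal_out_shape(x.shape, off, 0, 1) == list(np.diagonal(x, off, 0, 1).shape)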
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diagonal_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
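// Kernel shared by the forward and backward passes: for each flat index into
// the input (forward) or into Input@Grad (backward), it recovers the
// per-dimension indices from x_stride, checks whether that element lies on the
// requested diagonal (offset_/axis1_/axis2_), and then either gathers the
// element into the output (is_grad == false) or reads the matching value from
// the output gradient and writes it to that position, zero-filling
// non-diagonal positions (is_grad == true).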
template <typename T, int X_DIM_SIZE, int OUT_DIM_SIZE>
__global__ void Diagonal(const T* data1, T* data2, const int64_t offset_,
int64_t axis1_, int64_t axis2_, int64_t* x_stride,
int64_t* out_stride, int64_t numel, bool is_grad) {
CUDA_KERNEL_LOOP(idx, numel) {
int64_t idx_dim[X_DIM_SIZE] = {0};
int64_t temp = 0;
for (size_t i = 0; i < X_DIM_SIZE - 1; i++) {
idx_dim[i] = (idx - temp) / x_stride[i];
temp = temp + idx_dim[i] * x_stride[i];
}
idx_dim[X_DIM_SIZE - 1] = idx - temp;
int64_t axis1_dim = idx_dim[axis1_];
int64_t axis2_dim = idx_dim[axis2_];
int64_t out_dim[OUT_DIM_SIZE] = {0};
int temp_pos = 0;
for (int i = 0; i < X_DIM_SIZE; i++) {
if (i != axis1_ && i != axis2_) {
out_dim[temp_pos] = idx_dim[i];
temp_pos++;
}
}
bool flag = false;
if (offset_ == 0 && axis1_dim == axis2_dim) {
out_dim[temp_pos] = axis1_dim;
flag = true;
} else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
out_dim[temp_pos] = axis1_dim;
flag = true;
} else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
out_dim[temp_pos] = axis2_dim;
flag = true;
}
if (!is_grad) {
if (flag) {
int64_t idx_output = 0;
for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
idx_output = idx_output + out_dim[i] * out_stride[i];
}
idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
data2[idx_output] = data1[idx];
}
} else {
if (flag) {
int64_t idx_output = 0;
for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
idx_output = idx_output + out_dim[i] * out_stride[i];
}
idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
data2[idx] = data1[idx_output];
} else {
data2[idx] = static_cast<T>(0);
}
}
}
}
template <typename T>
class DiagonalCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("Input");
const auto* input_data = input->data<T>();
auto input_dim = input->dims().Get();
auto input_dim_size = input->dims().size();
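// The input strides are computed on the host and copied to device memory with
// TensorFromVector (instead of raw cudaMalloc/cudaMemcpy), so the kernel can
// decode flat indices on the GPU.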
std::vector<int64_t> res_in = vectorize(framework::stride(input->dims()));
paddle::framework::Tensor input_stride_tensor;
framework::TensorFromVector<int64_t>(res_in, context.device_context(),
&input_stride_tensor);
int64_t* input_stride = input_stride_tensor.data<int64_t>();
auto* output = context.Output<framework::Tensor>("Out");
auto* output_data = output->mutable_data<T>(context.GetPlace());
auto output_dim = output->dims().Get();
auto output_dim_size = output->dims().size();
std::vector<int64_t> res_out = vectorize(framework::stride(output->dims()));
paddle::framework::Tensor output_stride_tensor;
framework::TensorFromVector<int64_t>(res_out, context.device_context(),
&output_stride_tensor);
int64_t* output_stride = output_stride_tensor.data<int64_t>();
const int64_t offset_ = context.Attr<int>("offset");
const int64_t axis1 = context.Attr<int>("axis1");
int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
const int64_t axis2 = context.Attr<int>("axis2");
int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2;
int64_t numel = input->numel();
int threads = PADDLE_CUDA_NUM_THREADS;
int blocks = (numel + threads - 1) / threads;
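// Dispatch on the input rank so the kernel can keep its index arrays in
// fixed-size local storage (X_DIM_SIZE / OUT_DIM_SIZE are compile-time
// constants).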
switch (input_dim_size) {
case 2:
Diagonal<T, 2, 1><<<blocks, threads>>>(input_data, output_data, offset_,
axis1_, axis2_, input_stride,
output_stride, numel, false);
break;
case 3:
Diagonal<T, 3, 2><<<blocks, threads>>>(input_data, output_data, offset_,
axis1_, axis2_, input_stride,
output_stride, numel, false);
break;
case 4:
Diagonal<T, 4, 3><<<blocks, threads>>>(input_data, output_data, offset_,
axis1_, axis2_, input_stride,
output_stride, numel, false);
break;
case 5:
Diagonal<T, 5, 4><<<blocks, threads>>>(input_data, output_data, offset_,
axis1_, axis2_, input_stride,
output_stride, numel, false);
break;
case 6:
Diagonal<T, 6, 5><<<blocks, threads>>>(input_data, output_data, offset_,
axis1_, axis2_, input_stride,
output_stride, numel, false);
break;
case 7:
Diagonal<T, 7, 6><<<blocks, threads>>>(input_data, output_data, offset_,
axis1_, axis2_, input_stride,
output_stride, numel, false);
break;
case 8:
Diagonal<T, 8, 7><<<blocks, threads>>>(input_data, output_data, offset_,
axis1_, axis2_, input_stride,
output_stride, numel, false);
break;
case 9:
Diagonal<T, 9, 8><<<blocks, threads>>>(input_data, output_data, offset_,
axis1_, axis2_, input_stride,
output_stride, numel, false);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 10, but received %d.",
input_dim_size));
}
}
};
template <typename T>
class DiagonalGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
const auto* dout_data = dout->data<T>();
auto dout_dim = dout->dims().Get();
auto dout_dim_size = dout->dims().size();
std::vector<int64_t> res_dout = vectorize(framework::stride(dout->dims()));
paddle::framework::Tensor dout_stride_tensor;
framework::TensorFromVector<int64_t>(res_dout, context.device_context(),
&dout_stride_tensor);
int64_t* dout_stride = dout_stride_tensor.data<int64_t>();
auto* dx =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
auto* dx_data = dx->mutable_data<T>(context.GetPlace());
auto dx_dim = dx->dims().Get();
auto dx_dim_size = dx->dims().size();
std::vector<int64_t> res_dx = vectorize(framework::stride(dx->dims()));
paddle::framework::Tensor dx_stride_tensor;
framework::TensorFromVector<int64_t>(res_dx, context.device_context(),
&dx_stride_tensor);
int64_t* dx_stride = dx_stride_tensor.data<int64_t>();
const int64_t offset_ = context.Attr<int>("offset");
const int64_t axis1 = context.Attr<int>("axis1");
int64_t axis1_ = axis1 < 0 ? dx_dim_size + axis1 : axis1;
const int64_t axis2 = context.Attr<int>("axis2");
int64_t axis2_ = axis2 < 0 ? dx_dim_size + axis2 : axis2;
int64_t numel = dx->numel();
int threads = PADDLE_CUDA_NUM_THREADS;
int blocks = (numel + threads - 1) / threads;
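// The backward pass iterates over Input@Grad (numel = dx->numel()) and gathers
// from Out@Grad, so dx_stride takes the role of x_stride and dout_stride the
// role of out_stride in the shared kernel.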
switch (dx_dim_size) {
case 2:
Diagonal<T, 2, 1><<<blocks, threads>>>(dout_data, dx_data, offset_,
axis1_, axis2_, dx_stride,
dout_stride, numel, true);
break;
case 3:
Diagonal<T, 3, 2><<<blocks, threads>>>(dout_data, dx_data, offset_,
axis1_, axis2_, dx_stride,
dout_stride, numel, true);
break;
case 4:
Diagonal<T, 4, 3><<<blocks, threads>>>(dout_data, dx_data, offset_,
axis1_, axis2_, dx_stride,
dout_stride, numel, true);
break;
case 5:
Diagonal<T, 5, 4><<<blocks, threads>>>(dout_data, dx_data, offset_,
axis1_, axis2_, dx_stride,
dout_stride, numel, true);
break;
case 6:
Diagonal<T, 6, 5><<<blocks, threads>>>(dout_data, dx_data, offset_,
axis1_, axis2_, dx_stride,
dout_stride, numel, true);
break;
case 7:
Diagonal<T, 7, 6><<<blocks, threads>>>(dout_data, dx_data, offset_,
axis1_, axis2_, dx_stride,
dout_stride, numel, true);
break;
case 8:
Diagonal<T, 8, 7><<<blocks, threads>>>(dout_data, dx_data, offset_,
axis1_, axis2_, dx_stride,
dout_stride, numel, true);
break;
case 9:
Diagonal<T, 9, 8><<<blocks, threads>>>(dout_data, dx_data, offset_,
axis1_, axis2_, dx_stride,
dout_stride, numel, true);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of output(input@Grad) should be less than 10, but "
"received %d.",
dx_dim_size));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(diagonal, ops::DiagonalCUDAKernel<int>,
ops::DiagonalCUDAKernel<int64_t>,
ops::DiagonalCUDAKernel<float>,
ops::DiagonalCUDAKernel<double>,
ops::DiagonalCUDAKernel<plat::float16>,
ops::DiagonalCUDAKernel<bool>);
REGISTER_OP_CUDA_KERNEL(diagonal_grad, ops::DiagonalGradCUDAKernel<int>,
ops::DiagonalGradCUDAKernel<int64_t>,
ops::DiagonalGradCUDAKernel<float>,
ops::DiagonalGradCUDAKernel<double>,
ops::DiagonalGradCUDAKernel<plat::float16>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
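// Compute row-major (C-contiguous) element strides for a dim vector,
// e.g. dim = [2, 3, 4] -> strides = [12, 4, 1].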
template <typename T>
std::vector<T> ComputeDimStride(const std::vector<T> dim) {
size_t dim_size = dim.size();
std::vector<T> dim_strides;
dim_strides.resize(dim_size);
for (size_t i = 0; i < dim_size - 1; i++) {
size_t temp_stride = 1;
for (size_t j = i + 1; j < dim_size; j++) {
temp_stride = temp_stride * dim[j];
}
dim_strides[i] = temp_stride;
}
dim_strides[dim_size - 1] = 1;
return dim_strides;
}
template <typename T>
class DiagonalKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("Input");
const T* input_data = input->data<T>();
auto input_dim = vectorize(input->dims());
auto input_dim_size = input_dim.size();
auto* output = context.Output<framework::Tensor>("Out");
T* output_data = output->mutable_data<T>(context.GetPlace());
auto output_dim = vectorize(output->dims());
const int64_t offset_ = context.Attr<int>("offset");
const int64_t axis1 = context.Attr<int>("axis1");
int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
const int64_t axis2 = context.Attr<int>("axis2");
int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2;
std::vector<int64_t> input_stride = ComputeDimStride(input_dim);
std::vector<int64_t> output_stride = ComputeDimStride(output_dim);
int64_t numel = input->numel();
for (int64_t idx = 0; idx < numel; idx++) {
std::vector<int64_t> idx_dim(input_dim_size);
int64_t temp = 0;
for (size_t i = 0; i < input_dim_size; i++) {
idx_dim[i] = (idx - temp) / input_stride[i];
temp = temp + idx_dim[i] * input_stride[i];
}
int64_t axis1_dim = idx_dim[axis1_];
int64_t axis2_dim = idx_dim[axis2_];
idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_));
idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_));
bool flag = false;
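// An element lies on the selected diagonal when its index along axis2 equals
// its index along axis1 plus offset; its position along the new trailing
// diagonal axis is axis1_dim for offset >= 0 and axis2_dim for offset < 0.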
if (offset_ == 0 && axis1_dim == axis2_dim) {
idx_dim.push_back(axis1_dim);
flag = true;
} else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
idx_dim.push_back(axis1_dim);
flag = true;
} else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
idx_dim.push_back(axis2_dim);
flag = true;
}
if (flag) {
int64_t idx_output = 0;
for (size_t i = 0; i < idx_dim.size(); i++) {
idx_output = idx_output + idx_dim[i] * output_stride[i];
}
output_data[idx_output] = input_data[idx];
}
}
}
};
template <typename T>
class DiagonalGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
const T* dout_data = dout->data<T>();
auto dout_dim = vectorize(dout->dims());
auto* dx =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
T* dx_data = dx->mutable_data<T>(context.GetPlace());
auto dx_dim = vectorize(dx->dims());
auto dx_dim_size = dx_dim.size();
const int64_t offset_ = context.Attr<int>("offset");
const int64_t axis1 = context.Attr<int>("axis1");
int64_t axis1_ = axis1 < 0 ? dx_dim_size + axis1 : axis1;
const int64_t axis2 = context.Attr<int>("axis2");
int64_t axis2_ = axis2 < 0 ? dx_dim_size + axis2 : axis2;
std::vector<int64_t> dout_stride = ComputeDimStride(dout_dim);
std::vector<int64_t> dx_stride = ComputeDimStride(dx_dim);
int64_t numel = dx->numel();
for (int64_t idx = 0; idx < numel; idx++) {
std::vector<int64_t> idx_dim(dx_dim_size);
int64_t temp = 0;
for (size_t i = 0; i < dx_dim_size; i++) {
idx_dim[i] = (idx - temp) / dx_stride[i];
temp = temp + idx_dim[i] * dx_stride[i];
}
int64_t axis1_dim = idx_dim[axis1_];
int64_t axis2_dim = idx_dim[axis2_];
idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_));
idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_));
bool flag = false;
if (offset_ == 0 && axis1_dim == axis2_dim) {
idx_dim.push_back(axis1_dim);
flag = true;
} else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
idx_dim.push_back(axis1_dim);
flag = true;
} else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
idx_dim.push_back(axis2_dim);
flag = true;
}
if (flag) {
int64_t idx_output = 0;
for (size_t i = 0; i < idx_dim.size(); i++) {
idx_output = idx_output + idx_dim[i] * dout_stride[i];
}
dx_data[idx] = dout_data[idx_output];
} else {
dx_data[idx] = static_cast<T>(0);
}
}
}
};
} // namespace operators
} // namespace paddle
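Both the CPU kernels above and the CUDA kernels use the same stride trick to turn a flat element index back into per-dimension indices. The standalone sketch below (illustrative only, plain Python/numpy, not part of this commit) reproduces that decomposition and checks it against numpy.unravel_index:

import numpy as np

def compute_dim_stride(dims):
    # Row-major (C-contiguous) strides in elements, e.g. [2, 3, 4] -> [12, 4, 1].
    strides = [1] * len(dims)
    for i in range(len(dims) - 2, -1, -1):
        strides[i] = strides[i + 1] * dims[i + 1]
    return strides

dims = [2, 3, 4]
strides = compute_dim_stride(dims)
idx, idx_dim, rem = 17, [], 17
for s in strides:
    idx_dim.append(rem // s)
    rem -= idx_dim[-1] * s
print(idx_dim)                      # [1, 1, 1]
print(np.unravel_index(idx, dims))  # (1, 1, 1), the same multi-index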
@@ -203,6 +203,7 @@ from .tensor.math import erf # noqa: F401
from .tensor.math import addmm # noqa: F401
from .tensor.math import clip # noqa: F401
from .tensor.math import trace # noqa: F401
from .tensor.math import diagonal # noqa: F401
from .tensor.math import kron # noqa: F401
from .tensor.math import isfinite # noqa: F401
from .tensor.math import isinf # noqa: F401
@@ -503,5 +504,6 @@ __all__ = [ # noqa
'check_shape',
'trunc',
'digamma',
'standard_normal'
'standard_normal',
'diagonal'
]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.nn.functional as F
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.tensor as tensor
paddle.enable_static()
class TestDiagonalOp(OpTest):
def setUp(self):
self.op_type = "diagonal"
self.init_config()
self.outputs = {'Out': self.target}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Input'], 'Out')
def init_config(self):
self.case = np.random.randn(10, 5, 2).astype('float64')
self.inputs = {'Input': self.case}
self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'])
class TestDiagonalOpCase1(TestDiagonalOp):
def init_config(self):
self.case = np.random.randn(4, 2, 4, 4).astype('float32')
self.inputs = {'Input': self.case}
self.attrs = {'offset': -2, 'axis1': 3, 'axis2': 0}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'])
class TestDiagonalOpCase2(TestDiagonalOp):
def init_config(self):
self.case = np.random.randn(100, 100).astype('int64')
self.inputs = {'Input': self.case}
self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'])
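# With offset=0 on a square input, Out[k] = Input[k][k], so with dOut = ones(100)
# the analytic gradient of Input is the 100x100 identity matrix.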
self.grad_x = np.eye(100).astype('int64')
self.grad_out = np.ones(100).astype('int64')
def test_check_grad(self):
self.check_grad(
['Input'],
'Out',
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
class TestDiagonalOpCase3(TestDiagonalOp):
def init_config(self):
self.case = np.random.randint(0, 2, (4, 2, 4, 4)).astype('bool')
self.inputs = {'Input': self.case}
self.attrs = {'offset': -2, 'axis1': 3, 'axis2': 0}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'])
def test_check_grad(self):
pass
class TestDiagonalAPI(unittest.TestCase):
def setUp(self):
self.shape = [10, 3, 4]
self.x = np.random.random((10, 3, 4)).astype(np.float32)
self.place = paddle.CPUPlace()
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.shape)
out = paddle.diagonal(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x}, fetch_list=[out])
out_ref = np.diagonal(self.x)
for out in res:
self.assertEqual(np.allclose(out, out_ref, rtol=1e-08), True)
def test_api_dygraph(self):
paddle.disable_static(self.place)
x_tensor = paddle.to_tensor(self.x)
out = paddle.diagonal(x_tensor)
out_ref = np.diagonal(self.x)
self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
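As a complement to the user_defined_grads used in TestDiagonalOpCase2, the dygraph sketch below (illustrative only, not one of the committed tests; it assumes paddle.grad is available for dygraph autograd) checks that the gradient of diagonal(x).sum() on a square input is the identity matrix:

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.random.rand(4, 4), stop_gradient=False)
y = paddle.diagonal(x).sum()
dx, = paddle.grad([y], [x])                 # gradient of y w.r.t. x
print(np.allclose(dx.numpy(), np.eye(4)))   # expected: True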
@@ -171,6 +171,7 @@ from .math import trunc # noqa: F401
from .math import digamma # noqa: F401
from .math import neg # noqa: F401
from .math import lgamma # noqa: F401
from .math import diagonal # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
@@ -355,8 +356,9 @@ tensor_method_func = [ #noqa
'shape',
'real',
'imag',
'digamma',
'diagonal'
'trunc'
'digamma'
'bitwise_and',
'bitwise_or',
'bitwise_xor',
@@ -1696,6 +1696,114 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
outputs={'Out': [out]})
return out
def diagonal(x, offset=0, axis1=0, axis2=1, name=None):
"""
This OP computes the diagonals of the input tensor x.
If ``x`` is 2D, returns the diagonal.
If ``x`` has more than two dimensions, the diagonals are taken from the 2D planes specified by axis1 and axis2.
By default, the 2D planes are formed by the first and second axes of the input tensor x.
The argument ``offset`` determines where diagonals are taken from input tensor x:
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
Args:
x(Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be bool, int32, int64, float16, float32, float64.
offset(int, optional): Which diagonal in the input tensor x will be taken. Default: 0 (main diagonal).
axis1(int, optional): The first axis with respect to which to take the diagonal. Default: 0.
axis2(int, optional): The second axis with respect to which to take the diagonal. Default: 1.
name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.
Returns:
Tensor: a partial view of the input tensor along the two specified dimensions; the output data type is the same as the input data type.
Examples:
.. code-block:: python
import paddle
x = paddle.rand([2,2,3],'float32')
print(x)
# Tensor(shape=[2, 2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[[0.45661032, 0.03751532, 0.90191704],
# [0.43760979, 0.86177313, 0.65221709]],
# [[0.17020577, 0.00259554, 0.28954273],
# [0.51795638, 0.27325270, 0.18117726]]])
out1 = paddle.diagonal(x)
print(out1)
#Tensor(shape=[3, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0.45661032, 0.51795638],
# [0.03751532, 0.27325270],
# [0.90191704, 0.18117726]])
out2 = paddle.diagonal(x, offset=0, axis1=2, axis2=1)
print(out2)
#Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0.45661032, 0.86177313],
# [0.17020577, 0.27325270]])
out3 = paddle.diagonal(x, offset=1, axis1=0, axis2=1)
print(out3)
#Tensor(shape=[3, 1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0.43760979],
# [0.86177313],
# [0.65221709]])
out4 = paddle.diagonal(x, offset=0, axis1=1, axis2=2)
print(out4)
#Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0.45661032, 0.86177313],
# [0.17020577, 0.27325270]])
"""
def __check_input(input, offset, dim1, dim2):
check_dtype(x.dtype, 'Input',
['bool', 'int32', 'int64', 'float16', 'float32', 'float64'],
'diagonal')
input_shape = list(x.shape)
assert len(input_shape) >= 2, \
"The x must be at least 2-dimensional, " \
"But received Input x's dimensional: %s.\n" % \
len(input_shape)
axis1_ = axis1 if axis1 >= 0 else len(input_shape) + axis1
axis2_ = axis2 if axis2 >= 0 else len(input_shape) + axis2
assert axis1_ < len(input_shape), \
"The argument axis1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
% (-(len(input_shape)), len(input_shape) - 1, axis1)
assert axis2_ < len(input_shape), \
"The argument axis2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
% (-(len(input_shape)), len(input_shape) - 1, axis2)
assert axis1_ != axis2_, \
"axis1 and axis2 cannot be the same axis." \
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
if in_dygraph_mode():
return core.ops.diagonal(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
__check_input(input, offset, axis1, axis2)
helper = LayerHelper('diagonal', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='diagonal',
inputs={'Input': [x]},
attrs={'offset': offset,
'axis1': axis1,
'axis2': axis2},
outputs={'Out': [out]})
return out
@templatedoc(op_type="kron")
def kron(x, y, name=None):
"""