未验证 提交 077e5a0f 编写于 作者: L Li Fuchen 提交者: GitHub

Add trace op (#23873)

* add trace op, test=develop

* Optimized the implementation of trace op, test=develop

* fix a bug of include in trace_op.h, test=develop

* move trace API from creation to math, test=develop

* modified en doc. test=develop

* add complex trace api

* add complex sum api, test=develop

* modified en doc of complex sum and trace, test=develop

* modified doc and trace API, test=develop

* modified en doc of trace and sum, test=develop

* modified comment in complex kron API, test=develop

* OP Should Not Have Unused Input, test=develop

* add GetExpectedKernelType, test=develop
上级 fa43d74a
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/trace_op.h"
namespace paddle {
namespace operators {
class TraceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Input"), true,
platform::errors::NotFound("Input of TraceOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output of TraceOp is not found."));
int dim1 = ctx->Attrs().Get<int>("dim1");
int dim2 = ctx->Attrs().Get<int>("dim2");
auto x_dims = ctx->GetInputDim("Input");
int dim1_ = dim1 < 0 ? x_dims.size() + dim1 : dim1;
int dim2_ = dim2 < 0 ? x_dims.size() + dim2 : dim2;
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
platform::errors::OutOfRange(
"trace requires an tensor of at least two dimensions"));
PADDLE_ENFORCE_LT(
dim1_, x_dims.size(),
platform::errors::OutOfRange(
"Attr(dim1) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()), (x_dims.size() - 1), dim1));
PADDLE_ENFORCE_LT(
dim2_, x_dims.size(),
platform::errors::OutOfRange(
"Attr(dim2) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()), (x_dims.size() - 1), dim2));
PADDLE_ENFORCE_NE(dim1_, dim2_,
platform::errors::InvalidArgument(
"The dimensions should not be identical "
"%ld vs %ld.",
dim1, dim2));
auto sizes = vectorize(x_dims);
if (x_dims.size() == 2) {
sizes.clear();
sizes.push_back(1);
} else {
sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
}
ctx->SetOutputDim("Out", framework::make_ddim(sizes));
}
};
class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor) The input tensor, from which the diagonals are taken.");
AddOutput("Out", "(Tensor) the sum along diagonals of the input tensor");
AddAttr<int>(
"offset",
R"DOC((int, default 0), offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
)DOC")
.SetDefault(0);
AddAttr<int>(
"dim1",
R"DOC((int, default 0), the first dim of the 2-D planes from which the diagonals should be taken.
Can be both positive and negative. Default: 0.
)DOC")
.SetDefault(-2);
AddAttr<int>(
"dim2",
R"DOC((int, default 1), the second dim of the 2-D planes from which the diagonals should be taken.
Can be both positive and negative. Default: 1.
)DOC")
.SetDefault(-1);
AddComment(R"DOC(
Trace Operator.
Return the sum along diagonals of the input tensor.
The behavior of this operator is similar to how `numpy.trace` works.
If Input is 2-D, returns the sum of diagonal.
If Input has larger dimensions, then returns an tensor of diagonals sum, diagonals be taken from
the 2-D planes specified by dim1 and dim2.
)DOC");
}
};
class TraceOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Input"), true,
platform::errors::NotFound("Input(Input) of TraceOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Input")), true,
platform::errors::NotFound(
"Output(Input@GRAD) of TraceGradOp is not found."));
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
template <typename T>
class TraceGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("trace_grad");
grad_op->SetInput("Input", this->Input("Input"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("Input"),
this->InputGrad("Input"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInference,
"Input");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
ops::TraceGradOpMaker<paddle::framework::OpDesc>,
ops::TraceGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad,
ops::TraceGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, float>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, double>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/trace_op.h"
namespace paddle {
namespace operators {
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
};
template <typename DeviceContext, typename T>
class TraceCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("Input");
auto* out = context.Output<framework::Tensor>("Out");
const int64_t offset = context.Attr<int>("offset");
const int64_t dim1 = context.Attr<int>("dim1");
const int64_t dim2 = context.Attr<int>("dim2");
T* out_data = out->mutable_data<T>(context.GetPlace());
const framework::Tensor diag =
Diagonal<DeviceContext, T>(context, input, offset, dim1, dim2);
if (diag.numel() > 0) {
auto stream = context.cuda_device_context().stream();
std::vector<int> reduce_dims;
reduce_dims.push_back(out->dims().size());
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
diag, out, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor<T>(), stream);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
trace, ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
platform::float16>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
platform::float16>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
struct DiagFunctor {
DiagFunctor(const T* input, const int64_t* diag_stride,
const int64_t* ret_strides, int64_t pos, int64_t dim_size,
T* diag)
: input_(input),
diag_stride_(diag_stride),
ret_strides_(ret_strides),
pos_(pos),
dim_size_(dim_size),
diag_(diag) {}
HOSTDEVICE void operator()(size_t idx) const {
int64_t position = pos_;
int64_t num = idx;
for (int64_t i = 0; i < dim_size_; i++) {
position += num / diag_stride_[i] * ret_strides_[i];
num = num % diag_stride_[i];
}
diag_[idx] = input_[position];
}
const T* input_;
const int64_t* diag_stride_;
const int64_t* ret_strides_;
int64_t pos_;
int64_t dim_size_;
T* diag_;
};
template <typename T>
struct TraceGradFunctor {
TraceGradFunctor(const T* d_out, const int64_t* out_stride,
const int64_t* x_strides, int64_t pos, int64_t dim_size,
int64_t dim1, int64_t dim2, int64_t diag_size, T* d_x)
: d_out_(d_out),
out_stride_(out_stride),
x_strides_(x_strides),
pos_(pos),
dim_size_(dim_size),
dim1_(dim1),
dim2_(dim2),
diag_size_(diag_size),
d_x_(d_x) {}
HOSTDEVICE void operator()(size_t idx) const {
int64_t num = idx - pos_;
int64_t position = 0;
if (num >= 0) {
int64_t dim1 = 0;
int64_t dim2 = 0;
int64_t out_idx = 0;
for (int64_t i = 0; i < dim_size_; i++) {
if (i != dim1_ && i != dim2_) {
position += num / x_strides_[i] * out_stride_[out_idx++];
} else if (i == dim1_) {
dim1 = num / x_strides_[i];
} else {
dim2 = num / x_strides_[i];
}
num = num % x_strides_[i];
}
if (dim1 == dim2 && dim1 < diag_size_) {
d_x_[idx] = d_out_[position];
}
}
}
const T* d_out_;
const int64_t* out_stride_;
const int64_t* x_strides_;
int64_t pos_;
int64_t dim_size_;
int64_t dim1_;
int64_t dim2_;
int64_t diag_size_;
T* d_x_;
};
template <typename DeviceContext, typename T>
framework::Tensor Diagonal(const framework::ExecutionContext& context,
const framework::Tensor* input, int64_t offset,
int64_t dim1, int64_t dim2) {
auto* input_data = input->data<T>();
auto input_dims = input->dims();
auto input_stride = framework::stride(input_dims);
auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1;
auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2;
auto len1 = input_dims[std::min(dim1_, dim2_)];
auto len2 = input_dims[std::max(dim1_, dim2_)];
auto stride1 = input_stride[std::min(dim1_, dim2_)];
auto stride2 = input_stride[std::max(dim1_, dim2_)];
int offset_stride = 0;
if (offset >= 0) {
offset_stride = stride2;
len2 -= offset;
} else {
offset_stride = stride1;
len1 += offset;
}
int diag_size = len2 < len1 ? len2 : len1;
if (diag_size > 0) {
auto ret_strides = vectorize(input_stride);
auto ret_dims = vectorize(input_dims);
ret_strides.erase(ret_strides.begin() + std::max(dim1_, dim2_));
ret_strides.erase(ret_strides.begin() + std::min(dim1_, dim2_));
ret_dims.erase(ret_dims.begin() + std::max(dim1_, dim2_));
ret_dims.erase(ret_dims.begin() + std::min(dim1_, dim2_));
if (ret_strides.empty()) {
ret_strides.push_back(1);
ret_dims.push_back(1);
}
ret_strides.push_back(stride1 + stride2);
ret_dims.push_back(diag_size);
framework::Tensor diag;
framework::DDim diag_dims = framework::make_ddim(ret_dims);
auto dig_stride = framework::stride(diag_dims);
auto diag_data = diag.mutable_data<T>(diag_dims, context.GetPlace());
int64_t pos = std::abs(offset) * offset_stride;
int64_t dim_size = ret_strides.size();
#ifdef __NVCC__
thrust::device_vector<int64_t> diag_vec(vectorize(dig_stride));
const int64_t* diag_arr = thrust::raw_pointer_cast(diag_vec.data());
thrust::device_vector<int64_t> ret_vec(ret_strides);
const int64_t* ret_arr = thrust::raw_pointer_cast(ret_vec.data());
#else
auto* diag_arr = dig_stride.Get();
const auto* ret_arr = ret_strides.data();
#endif
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel());
DiagFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
diag_data);
for_range(functor);
return diag;
} else {
return {};
}
}
template <typename DeviceContext, typename T>
class TraceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("Input");
auto* out = context.Output<framework::Tensor>("Out");
const int64_t offset = context.Attr<int>("offset");
const int64_t dim1 = context.Attr<int>("dim1");
const int64_t dim2 = context.Attr<int>("dim2");
auto output_dims = out->dims();
out->mutable_data<T>(context.GetPlace());
const framework::Tensor diag =
Diagonal<DeviceContext, T>(context, input, offset, dim1, dim2);
if (diag.numel() > 0) {
auto x = framework::EigenMatrix<T>::Reshape(diag, diag.dims().size() - 1);
auto output = framework::EigenVector<T>::Flatten(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto reduce_dim = Eigen::array<int, 1>({1});
output.device(place) = x.sum(reduce_dim);
out->Resize(output_dims);
}
}
};
template <typename DeviceContext, typename T>
class TraceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
int64_t offset = context.Attr<int>("offset");
int64_t dim1 = context.Attr<int>("dim1");
int64_t dim2 = context.Attr<int>("dim2");
auto input_dims = d_x->dims();
auto input_stride = framework::stride(input_dims);
auto output_dims = d_out->dims();
auto output_stride = framework::stride(output_dims);
auto* out_data = d_out->data<T>();
T* x_data = d_x->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, d_x, static_cast<T>(0.0));
auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1;
auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2;
auto len1 = input_dims[std::min(dim1_, dim2_)];
auto len2 = input_dims[std::max(dim1_, dim2_)];
auto stride1 = input_stride[std::min(dim1_, dim2_)];
auto stride2 = input_stride[std::max(dim1_, dim2_)];
int offset_stride = 0;
if (offset >= 0) {
offset_stride = stride2;
len2 -= offset;
} else {
offset_stride = stride1;
len1 += offset;
}
int64_t diag_size = len2 < len1 ? len2 : len1;
int64_t pos = std::abs(offset) * offset_stride;
if (diag_size > 0) {
#ifdef __NVCC__
thrust::device_vector<int64_t> output_vec(vectorize(output_stride));
const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data());
thrust::device_vector<int64_t> input_vec(vectorize(input_stride));
const int64_t* input_arr = thrust::raw_pointer_cast(input_vec.data());
#else
const auto* output_arr = output_stride.Get();
const auto* input_arr = input_stride.Get();
#endif
platform::ForRange<DeviceContext> for_range(dev_ctx, d_x->numel());
TraceGradFunctor<T> functor(out_data, output_arr, input_arr, pos,
input_dims.size(), dim1_, dim2_, diag_size,
x_data);
for_range(functor);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -173,6 +173,7 @@ from .tensor.math import erf #DEFINE_ALIAS
from .tensor.math import addcmul #DEFINE_ALIAS
from .tensor.math import addmm #DEFINE_ALIAS
from .tensor.math import clamp #DEFINE_ALIAS
from .tensor.math import trace #DEFINE_ALIAS
from .tensor.math import kron #DEFINE_ALIAS
# from .tensor.random import gaussin #DEFINE_ALIAS
# from .tensor.random import uniform #DEFINE_ALIAS
......
......@@ -19,8 +19,13 @@ from ...fluid import layers
from ...tensor import math
__all__ = [
'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div',
'kron'
'elementwise_add',
'elementwise_sub',
'elementwise_mul',
'elementwise_div',
'kron',
'trace',
'sum',
]
......@@ -231,6 +236,106 @@ def elementwise_div(x, y, axis=-1, name=None):
name=name)
def trace(input, offset=0, dim1=0, dim2=1, name=None):
"""
The layer to compute the trace for a complex number tensor. input :attr:`input` must be a ComplexVariable.
See the detailed description for the function and other arguments
in :ref:`api_tensor_math_trace` .
Args:
input(ComplexVariable): The input ComplexVariable. Must be at least 2-dimensional.
The supported data types include complex64 and complex128.
offset(int, optional): Which diagonals in input tensor will be taken. Default: 0 (main diagonals).
dim1(int, optional): The first dimension with respect to take diagonal. Default: 0.
dim2(int, optional): The second dimension with respect to take diagonal. Default: 1.
name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.
Returns:
ComplexVariable: The trace result of input tensor, it's data type is the same as input data type.
Examples:
.. code-block:: python
import paddle
import paddle.fluid.dygraph as dg
import numpy as np
case1 = np.random.randn(3, 10, 10).astype('float64') + 1j * np.random.randn(3, 10, 10).astype('float64')
with dg.guard():
case1 = dg.to_variable(case1)
data1 = paddle.complex.trace(case1, offset=1, dim1=1, dim2=2) # data1.shape = [3]
"""
complex_variable_exists([input], "trace")
real = math.trace(input.real, offset, dim1, dim2, name)
imag = math.trace(input.imag, offset, dim1, dim2, name)
return ComplexVariable(real, imag)
def sum(input, dim=None, keep_dim=False, name=None):
"""
The layer to compute the sum for a complex number tensor elements over the given dimension. input :attr:`input` must be a ComplexVariable.
See the detailed description for the function and other arguments
in :ref:`api_tensor_math_sum` .
Args:
input(ComplexVariable): The input ComplexVariable with any number of dimensions.
The supported data types include complex64 and complex128.
dim (list|int, optional): The dimensions along which the sum is performed. If
:attr:`None`, sum all elements of :attr:`input` and return a
Tensor variable with a single element, otherwise must be in the
range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`,
the dimension to reduce is :math:`rank + dim[i]`.
keep_dim (bool, optional): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true, default
value is False.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
ComplexVariable: Results of summation operation on the specified dim of input tensor,
it's data type is the same as input.
Raises:
ValueError: the :attr:`dtype` must be float64 or int64.
Examples:
.. code-block:: python
import paddle.complex as cpx
import paddle.fluid.dygraph as dg
import numpy as np
with dg.guard():
# x is a Tensor variable with following elements:
# [[0.2, 0.3, 0.5, 0.9],
# [0.1, 0.2, 0.6, 0.7]]
# Each example is followed by the corresponding output tensor.
x = np.array([[0.2, 0.3, 0.5, 0.9],[0.1, 0.2, 0.6, 0.7]]) + 1j * np.array([[0.3, 0.4, 0.5, 0.2],[0.3, 0.6, 0.8, 0.3]])
x = dg.to_variable(x)
out1 = cpx.sum(x) # [3.5+3.4j]
out2 = cpx.sum(x, dim=0) # [0.3+0.6j, 0.5+1.j, 1.1+1.3j, 1.6+0.5j]
out3 = cpx.sum(x, dim=-1) # [1.9+1.4j, 1.6+2.j]
out4 = cpx.sum(x, dim=1, keep_dim=True) # [[1.9+1.4j], [1.6+2.j]]
# y is a Tensor variable with shape [2, 2, 2] and elements as below:
# [[[1, 2], [3, 4]],
# [[5, 6], [7, 8]]]
# Each example is followed by the corresponding output tensor.
y = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + 1j * np.array([[[4, 3], [2, 1]], [[8, 7], [6, 5]]])
y = dg.to_variable(y)
out5 = cpx.sum(y, dim=[1, 2]) # [10.+10.j, 26.+26.j]
out6 = cpx.sum(y, dim=[0, 1]) # [16.+20.j, 20.+16.j]
"""
complex_variable_exists([input], "sum")
real = math.sum(input.real, dim=dim, keep_dim=keep_dim, name=name)
imag = math.sum(input.imag, dim=dim, keep_dim=keep_dim, name=name)
return ComplexVariable(real, imag)
def kron(x, y, name=None):
"""
The kronecker product of two complex tensors. At least one of inputs :attr:`x`
......@@ -282,7 +387,7 @@ def kron(x, y, name=None):
complex_variable_exists([x, y], "kron")
# X = A + Bi, Y = C+Di
# kron(A, B) = kron(A, C) - kron(B, D) + (kron(A, D) + kron(B, C))i
# kron(X, Y) = kron(A, C) - kron(B, D) + (kron(A, D) + kron(B, C))i
(a, b) = (x.real, x.imag) if is_complex(x) else (x, None)
(c, d) = (y.real, y.imag) if is_complex(y) else (y, None)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from numpy.random import random as rand
import paddle.complex as cpx
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
class TestComplexSumLayer(unittest.TestCase):
def setUp(self):
self._dtype = "float64"
self._places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0))
def test_complex_x(self):
input = rand([2, 10, 10]).astype(self._dtype) + 1j * rand(
[2, 10, 10]).astype(self._dtype)
for place in self._places:
with dg.guard(place):
var_x = dg.to_variable(input)
result = cpx.sum(var_x, dim=[1, 2]).numpy()
target = np.sum(input, axis=(1, 2))
self.assertTrue(np.allclose(result, target))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from numpy.random import random as rand
import paddle.complex as cpx
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
class TestComplexTraceLayer(unittest.TestCase):
def setUp(self):
self._dtype = "float64"
self._places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0))
def test_complex_x(self):
input = rand([2, 20, 2, 3]).astype(self._dtype) + 1j * rand(
[2, 20, 2, 3]).astype(self._dtype)
for place in self._places:
with dg.guard(place):
var_x = dg.to_variable(input)
result = cpx.trace(var_x, offset=1, dim1=0, dim2=2).numpy()
target = np.trace(input, offset=1, axis1=0, axis2=2)
self.assertTrue(np.allclose(result, target))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.nn.functional as F
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.tensor as tensor
class TestTraceOp(OpTest):
def setUp(self):
self.op_type = "trace"
self.init_config()
self.outputs = {'Out': self.target}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Input'], 'Out')
def init_config(self):
self.case = np.random.randn(20, 6).astype('float64')
self.inputs = {'Input': self.case}
self.attrs = {'offset': 0, 'dim1': 0, 'dim2': 1}
self.target = np.trace(self.inputs['Input'])
class TestTraceOpCase1(TestTraceOp):
def init_config(self):
self.case = np.random.randn(2, 20, 2, 3).astype('float32')
self.inputs = {'Input': self.case}
self.attrs = {'offset': 1, 'dim1': 0, 'dim2': 2}
self.target = np.trace(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['dim1'],
axis2=self.attrs['dim2'])
class TestTraceOpCase2(TestTraceOp):
def init_config(self):
self.case = np.random.randn(2, 20, 2, 3).astype('float32')
self.inputs = {'Input': self.case}
self.attrs = {'offset': -5, 'dim1': 1, 'dim2': -1}
self.target = np.trace(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['dim1'],
axis2=self.attrs['dim2'])
class TestTraceAPICase(unittest.TestCase):
def test_case1(self):
case = np.random.randn(2, 20, 2, 3).astype('float32')
data1 = fluid.data(name='data1', shape=[2, 20, 2, 3], dtype='float32')
out1 = tensor.trace(data1)
out2 = tensor.trace(data1, offset=-5, dim1=1, dim2=-1)
place = core.CPUPlace()
exe = fluid.Executor(place)
results = exe.run(fluid.default_main_program(),
feed={"data1": case},
fetch_list=[out1, out2],
return_numpy=True)
target1 = np.trace(case)
target2 = np.trace(case, offset=-5, axis1=1, axis2=-1)
self.assertTrue(np.allclose(results[0], target1))
self.assertTrue(np.allclose(results[1], target2))
if __name__ == "__main__":
unittest.main()
......@@ -151,6 +151,7 @@ from .math import erf #DEFINE_ALIAS
from .math import addcmul #DEFINE_ALIAS
from .math import addmm #DEFINE_ALIAS
from .math import clamp #DEFINE_ALIAS
from .math import trace #DEFINE_ALIAS
from .math import kron #DEFINE_ALIAS
# from .random import gaussin #DEFINE_ALIAS
# from .random import uniform #DEFINE_ALIAS
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from __future__ import print_function
from ..fluid.framework import Variable, in_dygraph_mode
from ..fluid.framework import Variable
from ..fluid.initializer import Constant
from ..fluid.layers import core
from ..fluid.layer_helper import LayerHelper
......
......@@ -14,12 +14,13 @@
"""
math functions
"""
from __future__ import print_function
from paddle.common_ops_import import *
from ..fluid import layers
from ..fluid.framework import core, _varbase_creator
from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from ..fluid.layers.layer_function_generator import _generate_doc_string_
import sys
......@@ -111,6 +112,7 @@ __all__ = [
'addcmul',
'addmm',
'clamp',
'trace',
'kron'
]
......@@ -1520,6 +1522,99 @@ def clamp(input, min=None, max=None, output=None, name=None):
return output
def trace(input, offset=0, dim1=0, dim2=1, out=None, name=None):
"""
This OP computes the sum along diagonals of the input tensor.
If ``input`` is 2D, returns the sum of diagonal.
If ``input`` has larger dimensions, then returns an tensor of diagonals sum, diagonals be taken from
the 2D planes specified by dim1 and dim2. By default, the 2D planes formed by the first and second dimensions
of the input tensor.
The argument ``offset`` determines where diagonals are taken from input tensor:
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
Args:
input(Variable): The input tensor. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64.
offset(int, optional): Which diagonals in input tensor will be taken. Default: 0 (main diagonals).
dim1(int, optional): The first dimension with respect to take diagonal. Default: 0.
dim2(int, optional): The second dimension with respect to take diagonal. Default: 1.
name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.
Returns:
Variable: the output data type is the same as input data type.
Examples:
.. code-block:: python
import paddle
import paddle.fluid.dygraph as dg
import numpy as np
case1 = np.random.randn(2, 3).astype('float32')
case2 = np.random.randn(3, 10, 10).astype('float32')
case3 = np.random.randn(3, 10, 5, 10).astype('float32')
with dg.guard():
case1 = dg.to_variable(case1)
case2 = dg.to_variable(case2)
case3 = dg.to_variable(case3)
data1 = paddle.trace(case1) # data1.shape = [1]
data2 = paddle.trace(case2, offset=1, dim1=1, dim2=2) # data2.shape = [3]
data3 = paddle.trace(case3, offset=-3, dim1=1, dim2=-1) # data2.shape = [3, 5]
"""
inputs = {'Input': [input]}
attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2}
def __check_input(input, offset, dim1, dim2):
check_dtype(input.dtype, 'Input',
['int32', 'int64', 'float16', 'float32', 'float64'],
'trace')
input_shape = list(input.shape)
assert len(input_shape) >= 2, \
"The input must be at least 2-dimensional, " \
"But received Input's dimensional: %s.\n" % \
len(input_shape)
dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1
dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2
assert dim1_ < len(input_shape), \
"The argument dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
% (-(len(input_shape)), len(input_shape) - 1, dim1)
assert dim2_ < len(input_shape), \
"The argument dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
% (-(len(input_shape)), len(input_shape) - 1, dim2)
assert dim1_ != dim2_, \
"dim1 and dim2 cannot be the same dimension." \
"But received dim1 = %d, dim2 = %d\n"%(dim1, dim2)
if not in_dygraph_mode():
__check_input(input, offset, dim1, dim2)
helper = LayerHelper('trace', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=input.dtype)
else:
check_variable_and_dtype(out, 'out', ['float16', 'float32', 'float64', 'int32', 'int64'], 'trace')
helper.append_op(
type='trace',
inputs={'Input': [input]},
attrs={'offset': offset,
'dim1': dim1,
'dim2': dim2},
outputs={'Out': [out]})
return out
@templatedoc(op_type="kron")
def kron(x, y, out=None, name=None):
"""${comment}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册