From 077e5a0fe5fb5f75046760870473935497d31a35 Mon Sep 17 00:00:00 2001
From: Li Fuchen
Date: Mon, 27 Apr 2020 15:00:28 +0800
Subject: [PATCH] Add trace op (#23873)

* add trace op, test=develop

* Optimized the implementation of trace op, test=develop

* fix a bug of include in trace_op.h, test=develop

* move trace API from creation to math, test=develop

* modified en doc. test=develop

* add complex trace api

* add complex sum api, test=develop

* modified en doc of complex sum and trace, test=develop

* modified doc and trace API, test=develop

* modified en doc of trace and sum, test=develop

* modified comment in complex kron API, test=develop

* OP Should Not Have Unused Input, test=develop

* add GetExpectedKernelType, test=develop
---
 paddle/fluid/operators/trace_op.cc            | 172 ++++++++++++
 paddle/fluid/operators/trace_op.cu            |  70 +++++
 paddle/fluid/operators/trace_op.h             | 262 ++++++++++++++++++
 python/paddle/__init__.py                     |   1 +
 python/paddle/complex/tensor/math.py          | 111 +++++++-
 .../tests/unittests/test_complex_sum_layer.py |  42 +++
 .../unittests/test_complex_trace_layer.py     |  42 +++
 .../fluid/tests/unittests/test_trace_op.py    |  89 ++++++
 python/paddle/tensor/__init__.py              |   1 +
 python/paddle/tensor/creation.py              |   2 +-
 python/paddle/tensor/math.py                  |  99 ++++++-
 11 files changed, 885 insertions(+), 6 deletions(-)
 create mode 100644 paddle/fluid/operators/trace_op.cc
 create mode 100644 paddle/fluid/operators/trace_op.cu
 create mode 100644 paddle/fluid/operators/trace_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_complex_sum_layer.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_complex_trace_layer.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_trace_op.py

diff --git a/paddle/fluid/operators/trace_op.cc b/paddle/fluid/operators/trace_op.cc
new file mode 100644
index 00000000000..51399b68a1d
--- /dev/null
+++ b/paddle/fluid/operators/trace_op.cc
@@ -0,0 +1,172 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/trace_op.h"
+
+namespace paddle {
+namespace operators {
+
+class TraceOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Input"), true,
+        platform::errors::NotFound("Input of TraceOp is not found."));
+
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Out"), true,
+        platform::errors::NotFound("Output of TraceOp is not found."));
+
+    int dim1 = ctx->Attrs().Get<int>("dim1");
+    int dim2 = ctx->Attrs().Get<int>("dim2");
+
+    auto x_dims = ctx->GetInputDim("Input");
+
+    int dim1_ = dim1 < 0 ? x_dims.size() + dim1 : dim1;
+    int dim2_ = dim2 < 0 ? x_dims.size() + dim2 : dim2;
+
+    PADDLE_ENFORCE_GE(
+        x_dims.size(), 2,
+        platform::errors::OutOfRange(
+            "trace requires a tensor of at least two dimensions"));
+    PADDLE_ENFORCE_LT(
+        dim1_, x_dims.size(),
+        platform::errors::OutOfRange(
+            "Attr(dim1) is out of range (expected to be in range of [%ld, "
+            "%ld], but got %ld).",
+            -(x_dims.size()), (x_dims.size() - 1), dim1));
+    PADDLE_ENFORCE_LT(
+        dim2_, x_dims.size(),
+        platform::errors::OutOfRange(
+            "Attr(dim2) is out of range (expected to be in range of [%ld, "
+            "%ld], but got %ld).",
+            -(x_dims.size()), (x_dims.size() - 1), dim2));
+    PADDLE_ENFORCE_NE(dim1_, dim2_,
+                      platform::errors::InvalidArgument(
+                          "The dimensions should not be identical, "
+                          "%ld vs %ld.",
+                          dim1, dim2));
+
+    auto sizes = vectorize(x_dims);
+    if (x_dims.size() == 2) {
+      sizes.clear();
+      sizes.push_back(1);
+    } else {
+      sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
+      sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
+    }
+    ctx->SetOutputDim("Out", framework::make_ddim(sizes));
+  }
+};
+
+class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input",
+             "(Tensor) The input tensor, from which the diagonals are taken.");
+    AddOutput("Out", "(Tensor) The sum along diagonals of the input tensor.");
+    AddAttr<int>(
+        "offset",
+        R"DOC((int, default 0), offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
+        )DOC")
+        .SetDefault(0);
+    AddAttr<int>(
+        "dim1",
+        R"DOC((int, default -2), the first dim of the 2-D planes from which the diagonals should be taken.
+        Can be both positive and negative. Default: -2.
+        )DOC")
+        .SetDefault(-2);
+    AddAttr<int>(
+        "dim2",
+        R"DOC((int, default -1), the second dim of the 2-D planes from which the diagonals should be taken.
+        Can be both positive and negative. Default: -1.
+        )DOC")
+        .SetDefault(-1);
+    AddComment(R"DOC(
+Trace Operator.
+Return the sum along diagonals of the input tensor.
+The behavior of this operator is similar to how `numpy.trace` works.
+
+If Input is 2-D, returns the sum of the diagonal.
+If Input has more than two dimensions, returns a tensor of diagonal sums, with the
+diagonals taken from the 2-D planes specified by dim1 and dim2.
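+
+For example, given a 2-D Input [[1, 2], [3, 4]], the trace along the main
+diagonal is 1 + 4 = 5; with offset=1 the result is 2, and with offset=-1 it is 3.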
+
+)DOC");
+  }
+};
+
+class TraceOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Input"), true,
+        platform::errors::NotFound("Input(Input) of TraceGradOp is not found."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Input")), true,
+                      platform::errors::NotFound(
+                          "Output(Input@GRAD) of TraceGradOp is not found."));
+    ctx->SetOutputDim(framework::GradVarName("Input"),
+                      ctx->GetInputDim("Input"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class TraceGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("trace_grad");
+    grad_op->SetInput("Input", this->Input("Input"));
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("Input"),
+                       this->InputGrad("Input"));
+    grad_op->SetAttrMap(this->Attrs());
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInference,
+                                    "Input");
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
+                  ops::TraceGradOpMaker<paddle::framework::OpDesc>,
+                  ops::TraceGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad,
+                  ops::TraceGradNoNeedBufferVarsInference);
+REGISTER_OP_CPU_KERNEL(
+    trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::TraceKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TraceKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu
new file mode 100644
index 00000000000..ffba298cc23
--- /dev/null
+++ b/paddle/fluid/operators/trace_op.cu
@@ -0,0 +1,70 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
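+
+// The CUDA path gathers the selected diagonals with Diagonal() from
+// trace_op.h and sums them along the last axis with the cub-based
+// TensorReduce; trace_grad reuses the shared TraceGradKernel from trace_op.h.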
+
+#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/trace_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct IdentityFunctor {
+  HOSTDEVICE explicit inline IdentityFunctor() {}
+
+  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+};
+
+template <typename DeviceContext, typename T>
+class TraceCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("Input");
+    auto* out = context.Output<framework::Tensor>("Out");
+
+    const int64_t offset = context.Attr<int>("offset");
+    const int64_t dim1 = context.Attr<int>("dim1");
+    const int64_t dim2 = context.Attr<int>("dim2");
+
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+    const framework::Tensor diag =
+        Diagonal<DeviceContext, T>(context, input, offset, dim1, dim2);
+    if (diag.numel() > 0) {
+      auto stream = context.cuda_device_context().stream();
+      std::vector<int> reduce_dims;
+      reduce_dims.push_back(out->dims().size());
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+          diag, out, reduce_dims, static_cast<T>(0), cub::Sum(),
+          IdentityFunctor<T>(), stream);
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace platform = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(
+    trace, ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
+                         platform::float16>,
+    ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
+                         platform::float16>,
+    ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/trace_op.h b/paddle/fluid/operators/trace_op.h
new file mode 100644
index 00000000000..726efb82dd8
--- /dev/null
+++ b/paddle/fluid/operators/trace_op.h
@@ -0,0 +1,262 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
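+
+// Shared CPU/CUDA helpers for trace: DiagFunctor copies the selected
+// diagonal elements, Diagonal() materializes them as the last axis of a
+// temporary tensor, and TraceGradFunctor scatters Out@GRAD back onto the
+// matching diagonal positions of Input@GRAD.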
+
+#pragma once
+#include <algorithm>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct DiagFunctor {
+  DiagFunctor(const T* input, const int64_t* diag_stride,
+              const int64_t* ret_strides, int64_t pos, int64_t dim_size,
+              T* diag)
+      : input_(input),
+        diag_stride_(diag_stride),
+        ret_strides_(ret_strides),
+        pos_(pos),
+        dim_size_(dim_size),
+        diag_(diag) {}
+
+  HOSTDEVICE void operator()(size_t idx) const {
+    int64_t position = pos_;
+    int64_t num = idx;
+    for (int64_t i = 0; i < dim_size_; i++) {
+      position += num / diag_stride_[i] * ret_strides_[i];
+      num = num % diag_stride_[i];
+    }
+    diag_[idx] = input_[position];
+  }
+
+  const T* input_;
+  const int64_t* diag_stride_;
+  const int64_t* ret_strides_;
+  int64_t pos_;
+  int64_t dim_size_;
+  T* diag_;
+};
+
+template <typename T>
+struct TraceGradFunctor {
+  TraceGradFunctor(const T* d_out, const int64_t* out_stride,
+                   const int64_t* x_strides, int64_t pos, int64_t dim_size,
+                   int64_t dim1, int64_t dim2, int64_t diag_size, T* d_x)
+      : d_out_(d_out),
+        out_stride_(out_stride),
+        x_strides_(x_strides),
+        pos_(pos),
+        dim_size_(dim_size),
+        dim1_(dim1),
+        dim2_(dim2),
+        diag_size_(diag_size),
+        d_x_(d_x) {}
+
+  HOSTDEVICE void operator()(size_t idx) const {
+    int64_t num = idx - pos_;
+    int64_t position = 0;
+    if (num >= 0) {
+      int64_t dim1 = 0;
+      int64_t dim2 = 0;
+      int64_t out_idx = 0;
+      for (int64_t i = 0; i < dim_size_; i++) {
+        if (i != dim1_ && i != dim2_) {
+          position += num / x_strides_[i] * out_stride_[out_idx++];
+        } else if (i == dim1_) {
+          dim1 = num / x_strides_[i];
+        } else {
+          dim2 = num / x_strides_[i];
+        }
+        num = num % x_strides_[i];
+      }
+      if (dim1 == dim2 && dim1 < diag_size_) {
+        d_x_[idx] = d_out_[position];
+      }
+    }
+  }
+  const T* d_out_;
+  const int64_t* out_stride_;
+  const int64_t* x_strides_;
+  int64_t pos_;
+  int64_t dim_size_;
+  int64_t dim1_;
+  int64_t dim2_;
+  int64_t diag_size_;
+  T* d_x_;
+};
+
+template <typename DeviceContext, typename T>
+framework::Tensor Diagonal(const framework::ExecutionContext& context,
+                           const framework::Tensor* input, int64_t offset,
+                           int64_t dim1, int64_t dim2) {
+  auto* input_data = input->data<T>();
+  auto input_dims = input->dims();
+  auto input_stride = framework::stride(input_dims);
+  auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1;
+  auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2;
+  auto len1 = input_dims[std::min(dim1_, dim2_)];
+  auto len2 = input_dims[std::max(dim1_, dim2_)];
+  auto stride1 = input_stride[std::min(dim1_, dim2_)];
+  auto stride2 = input_stride[std::max(dim1_, dim2_)];
+
+  int offset_stride = 0;
+  if (offset >= 0) {
+    offset_stride = stride2;
+    len2 -= offset;
+  } else {
+    offset_stride = stride1;
+    len1 += offset;
+  }
+  int diag_size = len2 < len1 ? len2 : len1;
+
+  if (diag_size > 0) {
+    auto ret_strides = vectorize(input_stride);
+    auto ret_dims = vectorize(input_dims);
+    ret_strides.erase(ret_strides.begin() + std::max(dim1_, dim2_));
+    ret_strides.erase(ret_strides.begin() + std::min(dim1_, dim2_));
+    ret_dims.erase(ret_dims.begin() + std::max(dim1_, dim2_));
+    ret_dims.erase(ret_dims.begin() + std::min(dim1_, dim2_));
+    if (ret_strides.empty()) {
+      ret_strides.push_back(1);
+      ret_dims.push_back(1);
+    }
+    ret_strides.push_back(stride1 + stride2);
+    ret_dims.push_back(diag_size);
+    framework::Tensor diag;
+    framework::DDim diag_dims = framework::make_ddim(ret_dims);
+    auto dig_stride = framework::stride(diag_dims);
+    auto diag_data = diag.mutable_data<T>(diag_dims, context.GetPlace());
+
+    int64_t pos = std::abs(offset) * offset_stride;
+    int64_t dim_size = ret_strides.size();
+#ifdef __NVCC__
+    thrust::device_vector<int64_t> diag_vec(vectorize(dig_stride));
+    const int64_t* diag_arr = thrust::raw_pointer_cast(diag_vec.data());
+    thrust::device_vector<int64_t> ret_vec(ret_strides);
+    const int64_t* ret_arr = thrust::raw_pointer_cast(ret_vec.data());
+#else
+    auto* diag_arr = dig_stride.Get();
+    const auto* ret_arr = ret_strides.data();
+#endif
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel());
+    DiagFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
+                           diag_data);
+    for_range(functor);
+    return diag;
+  } else {
+    return {};
+  }
+}
+
+template <typename DeviceContext, typename T>
+class TraceKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("Input");
+    auto* out = context.Output<framework::Tensor>("Out");
+
+    const int64_t offset = context.Attr<int>("offset");
+    const int64_t dim1 = context.Attr<int>("dim1");
+    const int64_t dim2 = context.Attr<int>("dim2");
+
+    auto output_dims = out->dims();
+
+    out->mutable_data<T>(context.GetPlace());
+
+    const framework::Tensor diag =
+        Diagonal<DeviceContext, T>(context, input, offset, dim1, dim2);
+    if (diag.numel() > 0) {
+      auto x = framework::EigenMatrix<T>::Reshape(diag, diag.dims().size() - 1);
+      auto output = framework::EigenVector<T>::Flatten(*out);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({1});
+      output.device(place) = x.sum(reduce_dim);
+      out->Resize(output_dims);
+    }
+  }
+};
+
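+// TraceGradKernel zero-fills Input@GRAD, then copies each element of
+// Out@GRAD onto the positions of the (offset) diagonal selected by
+// dim1/dim2; all off-diagonal gradient entries remain zero.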
+template <typename DeviceContext, typename T>
+class TraceGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const auto* d_out =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* d_x =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+
+    int64_t offset = context.Attr<int>("offset");
+    int64_t dim1 = context.Attr<int>("dim1");
+    int64_t dim2 = context.Attr<int>("dim2");
+
+    auto input_dims = d_x->dims();
+    auto input_stride = framework::stride(input_dims);
+    auto output_dims = d_out->dims();
+    auto output_stride = framework::stride(output_dims);
+
+    auto* out_data = d_out->data<T>();
+    T* x_data = d_x->mutable_data<T>(context.GetPlace());
+
+    math::SetConstant<DeviceContext, T> set_zero;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    set_zero(dev_ctx, d_x, static_cast<T>(0.0));
+
+    auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1;
+    auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2;
+    auto len1 = input_dims[std::min(dim1_, dim2_)];
+    auto len2 = input_dims[std::max(dim1_, dim2_)];
+    auto stride1 = input_stride[std::min(dim1_, dim2_)];
+    auto stride2 = input_stride[std::max(dim1_, dim2_)];
+
+    int offset_stride = 0;
+    if (offset >= 0) {
+      offset_stride = stride2;
+      len2 -= offset;
+    } else {
+      offset_stride = stride1;
+      len1 += offset;
+    }
+    int64_t diag_size = len2 < len1 ? len2 : len1;
+    int64_t pos = std::abs(offset) * offset_stride;
+    if (diag_size > 0) {
+#ifdef __NVCC__
+      thrust::device_vector<int64_t> output_vec(vectorize(output_stride));
+      const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data());
+      thrust::device_vector<int64_t> input_vec(vectorize(input_stride));
+      const int64_t* input_arr = thrust::raw_pointer_cast(input_vec.data());
+#else
+      const auto* output_arr = output_stride.Get();
+      const auto* input_arr = input_stride.Get();
+#endif
+
+      platform::ForRange<DeviceContext> for_range(dev_ctx, d_x->numel());
+      TraceGradFunctor<T> functor(out_data, output_arr, input_arr, pos,
+                                  input_dims.size(), dim1_, dim2_, diag_size,
+                                  x_data);
+      for_range(functor);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 897967b0c14..5a8e874cd51 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -173,6 +173,7 @@ from .tensor.math import erf  #DEFINE_ALIAS
 from .tensor.math import addcmul  #DEFINE_ALIAS
 from .tensor.math import addmm  #DEFINE_ALIAS
 from .tensor.math import clamp  #DEFINE_ALIAS
+from .tensor.math import trace  #DEFINE_ALIAS
 from .tensor.math import kron  #DEFINE_ALIAS
 # from .tensor.random import gaussin  #DEFINE_ALIAS
 # from .tensor.random import uniform  #DEFINE_ALIAS
diff --git a/python/paddle/complex/tensor/math.py b/python/paddle/complex/tensor/math.py
index e21cdf76202..da53c98541e 100644
--- a/python/paddle/complex/tensor/math.py
+++ b/python/paddle/complex/tensor/math.py
@@ -19,8 +19,13 @@ from ...fluid import layers
 from ...tensor import math
 
 __all__ = [
-    'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div',
-    'kron'
+    'elementwise_add',
+    'elementwise_sub',
+    'elementwise_mul',
+    'elementwise_div',
+    'kron',
+    'trace',
+    'sum',
 ]
 
 
@@ -231,6 +236,106 @@ def elementwise_div(x, y, axis=-1, name=None):
         name=name)
 
 
+def trace(input, offset=0, dim1=0, dim2=1, name=None):
+    """
+    The layer to compute the trace of a complex tensor. The input :attr:`input` must be a ComplexVariable.
+    See the detailed description for the function and other arguments
+    in :ref:`api_tensor_math_trace` .
+
+    Args:
+        input(ComplexVariable): The input ComplexVariable. Must be at least 2-dimensional.
+            The supported data types include complex64 and complex128.
+        offset(int, optional): Which diagonal of the input tensor will be taken. Default: 0 (main diagonal).
+        dim1(int, optional): The first dimension from which the diagonal is taken. Default: 0.
+        dim2(int, optional): The second dimension from which the diagonal is taken. Default: 1.
+        name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.
+
+    Returns:
+        ComplexVariable: The trace result of the input tensor; its data type is the same as the input data type.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.fluid.dygraph as dg
+            import numpy as np
+
+            case1 = np.random.randn(3, 10, 10).astype('float64') + 1j * np.random.randn(3, 10, 10).astype('float64')
+
+            with dg.guard():
+                case1 = dg.to_variable(case1)
+                data1 = paddle.complex.trace(case1, offset=1, dim1=1, dim2=2)  # data1.shape = [3]
+    """
+    complex_variable_exists([input], "trace")
+    real = math.trace(input.real, offset, dim1, dim2, name=name)
+    imag = math.trace(input.imag, offset, dim1, dim2, name=name)
+
+    return ComplexVariable(real, imag)
+
+
+def sum(input, dim=None, keep_dim=False, name=None):
+    """
+    The layer to compute the sum of a complex tensor's elements over the given dimensions.
+    The input :attr:`input` must be a ComplexVariable.
+    See the detailed description for the function and other arguments
+    in :ref:`api_tensor_math_sum` .
+
+    Args:
+        input(ComplexVariable): The input ComplexVariable with any number of dimensions.
+            The supported data types include complex64 and complex128.
+        dim (list|int, optional): The dimensions along which the sum is performed. If
+            :attr:`None`, sum all elements of :attr:`input` and return a
+            Tensor variable with a single element, otherwise must be in the
+            range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`,
+            the dimension to reduce is :math:`rank + dim[i]`.
+        keep_dim (bool, optional): Whether to reserve the reduced dimension in the
+            output Tensor. The result tensor will have one fewer dimension
+            than the :attr:`input` unless :attr:`keep_dim` is true, default
+            value is False.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`
+
+    Returns:
+        ComplexVariable: Results of summation operation on the specified dim of the input tensor;
+        its data type is the same as the input.
+
+    Raises:
+        ValueError: the :attr:`dtype` of :attr:`input` must be complex64 or complex128.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.complex as cpx
+            import paddle.fluid.dygraph as dg
+            import numpy as np
+
+            with dg.guard():
+                # x is a Tensor variable with following elements:
+                #    [[0.2, 0.3, 0.5, 0.9],
+                #     [0.1, 0.2, 0.6, 0.7]]
+                # Each example is followed by the corresponding output tensor.
+                x = np.array([[0.2, 0.3, 0.5, 0.9],[0.1, 0.2, 0.6, 0.7]]) + 1j * np.array([[0.3, 0.4, 0.5, 0.2],[0.3, 0.6, 0.8, 0.3]])
+                x = dg.to_variable(x)
+                out1 = cpx.sum(x)  # [3.5+3.4j]
+                out2 = cpx.sum(x, dim=0)  # [0.3+0.6j, 0.5+1.j, 1.1+1.3j, 1.6+0.5j]
+                out3 = cpx.sum(x, dim=-1)  # [1.9+1.4j, 1.6+2.j]
+                out4 = cpx.sum(x, dim=1, keep_dim=True)  # [[1.9+1.4j], [1.6+2.j]]
+
+                # y is a Tensor variable with shape [2, 2, 2] and elements as below:
+                #      [[[1, 2], [3, 4]],
+                #       [[5, 6], [7, 8]]]
+                # Each example is followed by the corresponding output tensor.
+                y = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + 1j * np.array([[[4, 3], [2, 1]], [[8, 7], [6, 5]]])
+                y = dg.to_variable(y)
+                out5 = cpx.sum(y, dim=[1, 2])  # [10.+10.j, 26.+26.j]
+                out6 = cpx.sum(y, dim=[0, 1])  # [16.+20.j, 20.+16.j]
+
+    """
+    complex_variable_exists([input], "sum")
+    real = math.sum(input.real, dim=dim, keep_dim=keep_dim, name=name)
+    imag = math.sum(input.imag, dim=dim, keep_dim=keep_dim, name=name)
+    return ComplexVariable(real, imag)
+
+
 def kron(x, y, name=None):
     """
     The kronecker product of two complex tensors. At least one of inputs :attr:`x`
@@ -282,7 +287,7 @@ def kron(x, y, name=None):
     complex_variable_exists([x, y], "kron")
 
     # X = A + Bi, Y = C+Di
-    # kron(A, B) = kron(A, C) - kron(B, D) + (kron(A, D) + kron(B, C))i
+    # kron(X, Y) = kron(A, C) - kron(B, D) + (kron(A, D) + kron(B, C))i
     (a, b) = (x.real, x.imag) if is_complex(x) else (x, None)
     (c, d) = (y.real, y.imag) if is_complex(y) else (y, None)
 
diff --git a/python/paddle/fluid/tests/unittests/test_complex_sum_layer.py b/python/paddle/fluid/tests/unittests/test_complex_sum_layer.py
new file mode 100644
index 00000000000..e8110fd3a96
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_complex_sum_layer.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from numpy.random import random as rand
+import paddle.complex as cpx
+import paddle.fluid as fluid
+import paddle.fluid.dygraph as dg
+
+
+class TestComplexSumLayer(unittest.TestCase):
+    def setUp(self):
+        self._dtype = "float64"
+        self._places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            self._places.append(fluid.CUDAPlace(0))
+
+    def test_complex_x(self):
+        input = rand([2, 10, 10]).astype(self._dtype) + 1j * rand(
+            [2, 10, 10]).astype(self._dtype)
+        for place in self._places:
+            with dg.guard(place):
+                var_x = dg.to_variable(input)
+                result = cpx.sum(var_x, dim=[1, 2]).numpy()
+                target = np.sum(input, axis=(1, 2))
+                self.assertTrue(np.allclose(result, target))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_complex_trace_layer.py b/python/paddle/fluid/tests/unittests/test_complex_trace_layer.py
new file mode 100644
index 00000000000..97123597b6f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_complex_trace_layer.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from numpy.random import random as rand
+import paddle.complex as cpx
+import paddle.fluid as fluid
+import paddle.fluid.dygraph as dg
+
+
+class TestComplexTraceLayer(unittest.TestCase):
+    def setUp(self):
+        self._dtype = "float64"
+        self._places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            self._places.append(fluid.CUDAPlace(0))
+
+    def test_complex_x(self):
+        input = rand([2, 20, 2, 3]).astype(self._dtype) + 1j * rand(
+            [2, 20, 2, 3]).astype(self._dtype)
+        for place in self._places:
+            with dg.guard(place):
+                var_x = dg.to_variable(input)
+                result = cpx.trace(var_x, offset=1, dim1=0, dim2=2).numpy()
+                target = np.trace(input, offset=1, axis1=0, axis2=2)
+                self.assertTrue(np.allclose(result, target))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_trace_op.py b/python/paddle/fluid/tests/unittests/test_trace_op.py
new file mode 100644
index 00000000000..5d96d149a08
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_trace_op.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.nn.functional as F
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.tensor as tensor
+
+
+class TestTraceOp(OpTest):
+    def setUp(self):
+        self.op_type = "trace"
+        self.init_config()
+        self.outputs = {'Out': self.target}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Input'], 'Out')
+
+    def init_config(self):
+        self.case = np.random.randn(20, 6).astype('float64')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': 0, 'dim1': 0, 'dim2': 1}
+        self.target = np.trace(self.inputs['Input'])
+
+
+class TestTraceOpCase1(TestTraceOp):
+    def init_config(self):
+        self.case = np.random.randn(2, 20, 2, 3).astype('float32')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': 1, 'dim1': 0, 'dim2': 2}
+        self.target = np.trace(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['dim1'],
+            axis2=self.attrs['dim2'])
+
+
+class TestTraceOpCase2(TestTraceOp):
+    def init_config(self):
+        self.case = np.random.randn(2, 20, 2, 3).astype('float32')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': -5, 'dim1': 1, 'dim2': -1}
+        self.target = np.trace(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['dim1'],
+            axis2=self.attrs['dim2'])
+
+
+class TestTraceAPICase(unittest.TestCase):
+    def test_case1(self):
+        case = np.random.randn(2, 20, 2, 3).astype('float32')
+        data1 = fluid.data(name='data1', shape=[2, 20, 2, 3], dtype='float32')
+        out1 = tensor.trace(data1)
+        out2 = tensor.trace(data1, offset=-5, dim1=1, dim2=-1)
+
+        place = core.CPUPlace()
+        exe = fluid.Executor(place)
+        results = exe.run(fluid.default_main_program(),
+                          feed={"data1": case},
+                          fetch_list=[out1, out2],
+                          return_numpy=True)
+        target1 = np.trace(case)
+        target2 = np.trace(case, offset=-5, axis1=1, axis2=-1)
+        self.assertTrue(np.allclose(results[0], target1))
+        self.assertTrue(np.allclose(results[1], target2))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index a3bbc4879e4..ab8342ec820 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -151,6 +151,7 @@ from .math import erf  #DEFINE_ALIAS
 from .math import addcmul  #DEFINE_ALIAS
 from .math import addmm  #DEFINE_ALIAS
 from .math import clamp  #DEFINE_ALIAS
+from .math import trace  #DEFINE_ALIAS
 from .math import kron  #DEFINE_ALIAS
 # from .random import gaussin  #DEFINE_ALIAS
 # from .random import uniform  #DEFINE_ALIAS
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index 8c66dd7005a..ba40d70f59c 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from __future__ import print_function
-from ..fluid.framework import Variable, in_dygraph_mode
+from ..fluid.framework import Variable
 from ..fluid.initializer import Constant
 from ..fluid.layers import core
 from ..fluid.layer_helper import LayerHelper
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index b099c7f8653..5008b16e8e8 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -14,12 +14,13 @@
 """
 math functions
 """
-
 from __future__ import print_function
 
 from paddle.common_ops_import import *
 from ..fluid import layers
-from ..fluid.framework import core, _varbase_creator
+from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable
+from ..fluid.layer_helper import LayerHelper
+from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
 from ..fluid.layers.layer_function_generator import _generate_doc_string_
 import sys
 
@@ -111,6 +112,7 @@ __all__ = [
     'addcmul',
     'addmm',
     'clamp',
+    'trace',
     'kron'
 ]
 
@@ -1520,6 +1522,99 @@ def clamp(input, min=None, max=None, output=None, name=None):
     return output
 
 
+def trace(input, offset=0, dim1=0, dim2=1, out=None, name=None):
+    """
+    This OP computes the sum along diagonals of the input tensor.
+
+    If ``input`` is 2-D, returns the sum of the diagonal.
+
+    If ``input`` has more than two dimensions, returns a tensor of diagonal sums, with the
+    diagonals taken from the 2-D planes specified by dim1 and dim2. By default, the 2-D planes
+    are formed by the first and second dimensions of the input tensor.
+
+    The argument ``offset`` determines where diagonals are taken from the input tensor:
+
+    - If offset = 0, it is the main diagonal.
+    - If offset > 0, it is above the main diagonal.
+    - If offset < 0, it is below the main diagonal.
+
+    Args:
+        input(Variable): The input tensor. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64.
+        offset(int, optional): Which diagonal of the input tensor will be taken. Default: 0 (main diagonal).
+        dim1(int, optional): The first dimension from which the diagonal is taken. Default: 0.
+        dim2(int, optional): The second dimension from which the diagonal is taken. Default: 1.
+        out(Variable, optional): The output tensor. A new Variable is created if it is None. Default: None.
+        name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.
+
+    Returns:
+        Variable: the output data type is the same as input data type.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.fluid.dygraph as dg
+            import numpy as np
+
+            case1 = np.random.randn(2, 3).astype('float32')
+            case2 = np.random.randn(3, 10, 10).astype('float32')
+            case3 = np.random.randn(3, 10, 5, 10).astype('float32')
+
+            with dg.guard():
+                case1 = dg.to_variable(case1)
+                case2 = dg.to_variable(case2)
+                case3 = dg.to_variable(case3)
+                data1 = paddle.trace(case1)  # data1.shape = [1]
+                data2 = paddle.trace(case2, offset=1, dim1=1, dim2=2)  # data2.shape = [3]
+                data3 = paddle.trace(case3, offset=-3, dim1=1, dim2=-1)  # data3.shape = [3, 5]
+    """
+    inputs = {'Input': [input]}
+    attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2}
+
+    def __check_input(input, offset, dim1, dim2):
+        check_dtype(input.dtype, 'Input',
+                    ['int32', 'int64', 'float16', 'float32', 'float64'],
+                    'trace')
+
+        input_shape = list(input.shape)
+        assert len(input_shape) >= 2, \
+            "The input must be at least 2-dimensional, " \
+            "but the received input is %s-dimensional.\n" % \
+            len(input_shape)
+
+        dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1
+        dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2
+
+        assert dim1_ < len(input_shape), \
+            "The argument dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
+            % (-(len(input_shape)), len(input_shape) - 1, dim1)
+
+        assert dim2_ < len(input_shape), \
+            "The argument dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
+            % (-(len(input_shape)), len(input_shape) - 1, dim2)
+
+        assert dim1_ != dim2_, \
+            "dim1 and dim2 cannot be the same dimension. " \
+            "But received dim1 = %d, dim2 = %d\n" % (dim1, dim2)
+
+    if not in_dygraph_mode():
+        __check_input(input, offset, dim1, dim2)
+    helper = LayerHelper('trace', **locals())
+
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    else:
+        check_variable_and_dtype(
+            out, 'out', ['float16', 'float32', 'float64', 'int32', 'int64'],
+            'trace')
+
+    helper.append_op(
+        type='trace', inputs=inputs, attrs=attrs, outputs={'Out': [out]})
+    return out
+
 @templatedoc(op_type="kron")
 def kron(x, y, out=None, name=None):
     """${comment}
--
GitLab