From e01262e691dc8d847f5152b84bf0fc12be3cb5db Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Sat, 25 Apr 2020 14:11:01 +0800 Subject: [PATCH] add kron op (#24105) * add kron op and its python API, doc and unittests. * add kron in paddle.complex --- paddle/fluid/operators/kron_op.cc | 168 +++++++++ paddle/fluid/operators/kron_op.cu | 33 ++ paddle/fluid/operators/kron_op.h | 325 ++++++++++++++++++ python/paddle/__init__.py | 1 + python/paddle/complex/tensor/math.py | 85 ++++- .../tests/unittests/test_complex_kron.py | 70 ++++ .../fluid/tests/unittests/test_kron_op.py | 101 ++++++ python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/math.py | 62 ++++ 9 files changed, 845 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/kron_op.cc create mode 100644 paddle/fluid/operators/kron_op.cu create mode 100644 paddle/fluid/operators/kron_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_complex_kron.py create mode 100644 python/paddle/fluid/tests/unittests/test_kron_op.py diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc new file mode 100644 index 0000000000..a98d56d6fc --- /dev/null +++ b/paddle/fluid/operators/kron_op.cc @@ -0,0 +1,168 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +#include "paddle/fluid/operators/kron_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +class KronOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kron"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "kron"); + + auto dim_x = ctx->GetInputDim("X"); + auto dim_y = ctx->GetInputDim("Y"); + auto rank_x = dim_x.size(); + auto rank_y = dim_y.size(); + auto rank = (rank_x > rank_y) ? rank_x : rank_y; + + std::vector dim_out; + dim_out.reserve(rank); + for (int i = 0; i < rank; i++) { + int64_t dim_xi = (i < rank - rank_x) ? 1 : dim_x.at(i - (rank - rank_x)); + int64_t dim_yi = (i < rank - rank_y) ? 1 : dim_y.at(i - (rank - rank_y)); + dim_out.push_back(dim_xi == -1 || dim_yi == -1 ? -1 : dim_xi * dim_yi); + } + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class KronOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), the first operand of kron op"); + AddInput("Y", "(Tensor), the second operand of kron op"); + AddOutput("Out", "(Tensor), the output of kron op."); + AddComment(R"DOC( + Kron Operator. + + This operator computes the Kronecker product of two tensors, a + composite tensor made of blocks of the second tensor scaled by the + first. + + This operator assumes that the rank of the two tensors, $X$ and $Y$ + are the same, if necessary prepending the smallest with ones. If the + shape of $X$ is [$r_0$, $r_1$, ..., $r_N$] and the shape of $Y$ is + [$s_0$, $s_1$, ..., $s_N$], then the shape of the output tensor is + [$r_{0}s_{0}$, $r_{1}s_{1}$, ..., $r_{N}s_{N}$]. The elements are + products of elements from $X$ and $Y$. + + The equation is: + $$ + output[k_{0}, k_{1}, ..., k_{N}] = X[i_{0}, i_{1}, ..., i_{N}] * + Y[j_{0}, j_{1}, ..., j_{N}] + $$ + + where + $$ + k_{t} = i_{t} * s_{t} + j_{t}, t = 0, 1, ..., N + $$ + )DOC"); + } +}; + +class KronGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kron_grad"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "kron_grad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "kron_grad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Y")), "Output", + framework::GradVarName("Y"), "kron_grad"); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ x_grad_name); + ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y")); + ctx->ShareLoD("Y", /*->*/ y_grad_name); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto out_grad_name = framework::GradVarName("Out"); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name), + ctx.GetPlace()); + } +}; + +template +class KronGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("kron_grad"); + + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput("Y", this->Input("Y")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + + grad_op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(kron, ops::KronOp, ops::KronOpMaker, + ops::KronGradOpMaker, + ops::KronGradOpMaker); +REGISTER_OP_CPU_KERNEL( + kron, ops::KronKernel, + ops::KronKernel, + ops::KronKernel, + ops::KronKernel, + ops::KronKernel); + +REGISTER_OPERATOR(kron_grad, ops::KronGradOp); +REGISTER_OP_CPU_KERNEL( + kron_grad, ops::KronGradKernel, + ops::KronGradKernel, + ops::KronGradKernel, + ops::KronGradKernel, + ops::KronGradKernel); diff --git a/paddle/fluid/operators/kron_op.cu b/paddle/fluid/operators/kron_op.cu new file mode 100644 index 0000000000..02eeefeabb --- /dev/null +++ b/paddle/fluid/operators/kron_op.cu @@ -0,0 +1,33 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/kron_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + kron, ops::KronKernel, + ops::KronKernel, + ops::KronKernel, + ops::KronKernel, + ops::KronKernel); + +REGISTER_OP_CUDA_KERNEL( + kron_grad, ops::KronGradKernel, + ops::KronGradKernel, + ops::KronGradKernel, + ops::KronGradKernel, + ops::KronGradKernel); diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h new file mode 100644 index 0000000000..ec7a8a7d9b --- /dev/null +++ b/paddle/fluid/operators/kron_op.h @@ -0,0 +1,325 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" +#if __NVCC__ +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "thrust/device_vector.h" +#endif + +namespace paddle { +namespace operators { + +// Process an element in the output, used with a parallel-for +template +struct KronElemFunctor { + KronElemFunctor(const T* a, const T* b, T* out, const int64_t* shape_b, + const int64_t* stride_a, const int64_t* stride_b, + const int64_t* stride_out, int ndims) + : a_(a), + b_(b), + out_(out), + shape_b_(shape_b), + stride_a_(stride_a), + stride_b_(stride_b), + stride_out_(stride_out), + ndims_(ndims) {} + + HOSTDEVICE void operator()(int64_t idx) const { + // it computes 1 element in the output + int64_t index = idx; + int64_t index_a = 0; + int64_t index_b = 0; + for (int i = 0; i < ndims_; i++) { + auto pos_i = index / stride_out_[i]; + index = index % stride_out_[i]; + auto pos_ai = pos_i / shape_b_[i]; + auto pos_bi = pos_i % shape_b_[i]; + index_a += stride_a_[i] * pos_ai; + index_b += stride_b_[i] * pos_bi; + } + out_[idx] = a_[index_a] * b_[index_b]; + } + + private: + const T* a_; + const T* b_; + T* out_; + const int64_t* shape_b_; + const int64_t* stride_a_; + const int64_t* stride_b_; + const int64_t* stride_out_; + const int ndims_; +}; + +template +struct KronOpFunctor { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor& x, + const framework::Tensor& y, framework::Tensor* out) { + int ndims = out->dims().size(); + int64_t numel = out->numel(); + + const framework::DDim& dim_x = x.dims(); + const framework::DDim& dim_y = y.dims(); + const framework::DDim& dim_out = out->dims(); + const framework::DDim stride_x = framework::stride(dim_x); + const framework::DDim stride_y = framework::stride(dim_y); + const framework::DDim stride_out = framework::stride(dim_out); + + const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr, + *p_stride_out = nullptr, *p_shape_y = nullptr; +#if __NVCC__ + thrust::device_vector d_stride_x(ndims); + thrust::device_vector d_stride_y(ndims); + thrust::device_vector d_stride_out(ndims); + thrust::device_vector d_shape_y(ndims); + thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin()); + thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin()); + thrust::copy(stride_out.Get(), stride_out.Get() + ndims, + d_stride_out.begin()); + thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin()); + + p_stride_x = thrust::raw_pointer_cast(d_stride_x.data()); + p_stride_y = thrust::raw_pointer_cast(d_stride_y.data()); + p_stride_out = thrust::raw_pointer_cast(d_stride_out.data()); + p_shape_y = thrust::raw_pointer_cast(d_shape_y.data()); +#else + p_stride_x = stride_x.Get(); + p_stride_y = stride_y.Get(); + p_stride_out = stride_out.Get(); + p_shape_y = dim_y.Get(); +#endif + + platform::ForRange for_range(dev_ctx, numel); + KronElemFunctor functor(x.data(), y.data(), out->data(), + p_shape_y, p_stride_x, p_stride_y, p_stride_out, + ndims); + for_range(functor); + } +}; + +template +struct KronGradElemFunctor { + KronGradElemFunctor(const T* dout, const T* A, const T* B, T* dout_a, + T* dout_b, const int64_t* stride_dout, + const int64_t* stride_a, const int64_t* stride_b, + const int64_t* shape_b, const int64_t numel_a, + const int64_t numel_b, const int ndims) + : dout_(dout), + A_(A), + B_(B), + dout_a_(dout_a), + dout_b_(dout_b), + stride_dout_(stride_dout), + stride_a_(stride_a), + stride_b_(stride_b), + shape_b_(shape_b), + numel_a_(numel_a), + numel_b_(numel_b), + ndims_(ndims) {} + + HOSTDEVICE void operator()(int64_t idx) { + int64_t index = idx; + int64_t index_a = 0; + int64_t index_b = 0; + for (int i = 0; i < ndims_; i++) { + auto pos_i = index / stride_dout_[i]; + index = index % stride_dout_[i]; + auto pos_ai = pos_i / shape_b_[i]; + auto pos_bi = pos_i % shape_b_[i]; + index_a += stride_a_[i] * pos_ai; + index_b += stride_b_[i] * pos_bi; + } + + size_t index_out_a = index_a * numel_b_ + index_b; + size_t index_out_b = index_b * numel_a_ + index_a; + + dout_a_[index_out_a] = dout_[idx] * B_[index_b]; + dout_b_[index_out_b] = dout_[idx] * A_[index_a]; + } + + private: + const T* dout_; + const T* A_; + const T* B_; + T* dout_a_; + T* dout_b_; + const int64_t* stride_dout_; + const int64_t* stride_a_; + const int64_t* stride_b_; + const int64_t* shape_b_; + const int64_t numel_a_; + const int64_t numel_b_; + const int ndims_; +}; + +template +struct IdentityFunctor { + HOSTDEVICE explicit inline IdentityFunctor() {} + + HOSTDEVICE inline T operator()(const T& x) const { return x; } +}; + +template +struct KronGradOpFunctor { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor& dout, + const framework::Tensor& x, const framework::Tensor& y, + framework::Tensor* dx, framework::Tensor* dy) { + int ndims = dout.dims().size(); + int64_t numel = dout.numel(); + int64_t numel_x = x.numel(); + int64_t numel_y = y.numel(); + + const framework::DDim& dim_x = x.dims(); + const framework::DDim& dim_y = y.dims(); + const framework::DDim& dim_dout = dout.dims(); + + const framework::DDim stride_x = framework::stride(dim_x); + const framework::DDim stride_y = framework::stride(dim_y); + const framework::DDim stride_dout = framework::stride(dim_dout); + + const int64_t* p_stride_x = nullptr; + const int64_t* p_stride_y = nullptr; + const int64_t* p_stride_dout = nullptr; + const int64_t* p_shape_y = nullptr; +#if __NVCC__ + thrust::device_vector d_stride_x(ndims); + thrust::device_vector d_stride_y(ndims); + thrust::device_vector d_stride_dout(ndims); + thrust::device_vector d_shape_y(ndims); + thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin()); + thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin()); + thrust::copy(stride_dout.Get(), stride_dout.Get() + ndims, + d_stride_dout.begin()); + thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin()); + + p_stride_x = thrust::raw_pointer_cast(d_stride_x.data()); + p_stride_y = thrust::raw_pointer_cast(d_stride_y.data()); + p_stride_dout = thrust::raw_pointer_cast(d_stride_dout.data()); + p_shape_y = thrust::raw_pointer_cast(d_shape_y.data()); +#else + p_stride_x = stride_x.Get(); + p_stride_y = stride_y.Get(); + p_stride_dout = stride_dout.Get(); + p_shape_y = dim_y.Get(); +#endif + // dout_x: dout * kron(ones(X), Y) re-aranged in shape (numel_x, numel_y) + // dout_y: dout * kron(X, ones(Y)) re-aranged in shaoe (numel_y, numel_x) + framework::Tensor dout_x; + dout_x.mutable_data({numel_x, numel_y}, dev_ctx.GetPlace()); + framework::Tensor dout_y; + dout_y.mutable_data({numel_y, numel_x}, dev_ctx.GetPlace()); + + platform::ForRange for_range(dev_ctx, numel); + KronGradElemFunctor func(dout.data(), x.data(), y.data(), + dout_x.data(), dout_y.data(), + p_stride_dout, p_stride_x, p_stride_y, + p_shape_y, numel_x, numel_y, ndims); + for_range(func); + +// reduce_sum along aixs 1 +#if __NVCC__ + auto stream = dev_ctx.stream(); // it is a cuda device_context + TensorReduce>( + dout_x, dx, {1}, static_cast(0), cub::Sum(), IdentityFunctor(), + stream); + TensorReduce>( + dout_y, dy, {1}, static_cast(0), cub::Sum(), IdentityFunctor(), + stream); +#else + auto eigen_dout_x = framework::EigenMatrix::Reshape(dout_x, 1); + auto eigen_dout_y = framework::EigenMatrix::Reshape(dout_y, 1); + auto eigen_vec_dx = framework::EigenVector::Flatten(*dx); + auto eigen_vec_dy = framework::EigenVector::Flatten(*dy); + auto* place = dev_ctx.eigen_device(); + Eigen::array reduce_dim = {1}; + eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim); + eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim); +#endif + } +}; + +inline framework::Tensor UnsqueezeTo(const framework::Tensor& src, int ndims) { + const framework::DDim& shape = src.dims(); + int rank = shape.size(); + framework::Tensor res; + res.ShareDataWith(src); + PADDLE_ENFORCE_LE( + rank, ndims, + platform::errors::InvalidArgument( + "The input Tensor's rank should be less than or equal to ndims" + "Received input Tensor's rank = %d, ndims = %d", + rank, ndims)); + if (rank < ndims) { + std::vector new_dim(ndims, 1); + for (int i = ndims - rank; i < ndims; i++) { + new_dim[i] = shape[i - ndims + rank]; + } + res.Resize(framework::make_ddim(new_dim)); + } + return res; +} + +template +class KronKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& ctx) const { + auto& dev_ctx = ctx.template device_context(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + int ndims = out->dims().size(); + framework::Tensor xx = UnsqueezeTo(*x, ndims); + framework::Tensor yy = UnsqueezeTo(*y, ndims); + + KronOpFunctor func; + func(dev_ctx, xx, yy, out); + } +}; + +template +class KronGradKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& ctx) const { + auto& dev_ctx = ctx.template device_context(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + dx->mutable_data(ctx.GetPlace()); + dy->mutable_data(ctx.GetPlace()); + + int ndims = dout->dims().size(); + framework::Tensor xx = UnsqueezeTo(*x, ndims); + framework::Tensor dxx = UnsqueezeTo(*dx, ndims); + framework::Tensor yy = UnsqueezeTo(*y, ndims); + framework::Tensor dyy = UnsqueezeTo(*dy, ndims); + + KronGradOpFunctor func; + func(dev_ctx, *dout, xx, yy, &dxx, &dyy); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 6020ed68b5..8abd1eb7ee 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -139,6 +139,7 @@ from .tensor.math import min #DEFINE_ALIAS from .tensor.math import mm #DEFINE_ALIAS from .tensor.math import div #DEFINE_ALIAS from .tensor.math import add #DEFINE_ALIAS +from .tensor.math import kron #DEFINE_ALIAS # from .tensor.math import atan #DEFINE_ALIAS from .tensor.math import logsumexp #DEFINE_ALIAS # from .tensor.math import inverse #DEFINE_ALIAS diff --git a/python/paddle/complex/tensor/math.py b/python/paddle/complex/tensor/math.py index ba0bc77620..e21cdf7620 100644 --- a/python/paddle/complex/tensor/math.py +++ b/python/paddle/complex/tensor/math.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from paddle.common_ops_import import * from ..helper import is_complex, is_real, complex_variable_exists from ...fluid.framework import ComplexVariable from ...fluid import layers +from ...tensor import math __all__ = [ - 'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div' + 'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div', + 'kron' ] @@ -37,6 +40,9 @@ def elementwise_add(x, y, axis=-1, name=None): with any number of dimensions. The supported data types include float32 and float64 when it is a Variable. Otherwise the supported data types are complex64 or complex128. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. Examples: .. code-block:: python @@ -85,6 +91,9 @@ def elementwise_sub(x, y, axis=-1, name=None): with any number of dimensions. The supported data types include float32 and float64 when it is a Variable. Otherwise the supported data types are complex64 or complex128. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. Examples: .. code-block:: python @@ -133,6 +142,9 @@ def elementwise_mul(x, y, axis=-1, name=None): with any number of dimensions. The supported data types include float32 and float64 when it is a Variable. Otherwise the supported data types are complex64 or complex128. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. Examples: .. code-block:: python @@ -184,6 +196,9 @@ def elementwise_div(x, y, axis=-1, name=None): with any number of dimensions. The supported data types include float32 and float64 when it is a Variable. Otherwise the supported data types are complex64 or complex128. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. Examples: .. code-block:: python @@ -214,3 +229,71 @@ def elementwise_div(x, y, axis=-1, name=None): e, axis=axis, name=name) + + +def kron(x, y, name=None): + """ + The kronecker product of two complex tensors. At least one of inputs :attr:`x` + and :attr:`y` must be a ComplexVariable. See the detailed description for + the function and other arguments in :ref:`api_paddle_tensor_kron` . + + Let $x = a + ib$, and $y = c + id$, the euqation is + + .. math:: + kron(x, y) = kron(a, c) - kron(b, d) + i(kron(a, d) + kron(b, c)) + + Args: + x (Variable|ComplexVariable): The first input Variable or ComplexVariable + with any number of dimensions. The supported data types include float32 + and float64 when it is a Variable. Otherwise the supported data types + are complex64 or complex128. + y (Variable|ComplexVariable): The second input Variable or ComplexVariable + with any number of dimensions. The supported data types include float32 + and float64 when it is a Variable. Otherwise the supported data types + are complex64 or complex128. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + ComplexVariable: The kronecker product, data type: complex64 or complex128, depending on the data type of x and y. If the data types of x and y are float32/complex64, the data type of the output is complex64, else if the data types of x and y are float64/complex128, the data type of the output is complex128. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + import paddle.fluid.dygraph as dg + + a = np.array([[1.0+1.0j, 2.0+1.0j], [3.0+1.0j, 4.0+1.0j]]) + b = np.array([[5.0+2.0j, 6.0+2.0j], [7.0+2.0j, 8.0+2.0j]]) + + place = fluid.CPUPlace() + with dg.guard(place): + x = dg.to_variable(a) + y = dg.to_variable(b) + out = paddle.complex.kron(x, y) + print(out.numpy()) + # [[ 3. +7.j 4. +8.j 8. +9.j 10.+10.j] + # [ 5. +9.j 6.+10.j 12.+11.j 14.+12.j] + # [13.+11.j 16.+12.j 18.+13.j 22.+14.j] + # [19.+13.j 22.+14.j 26.+15.j 30.+16.j]] + """ + complex_variable_exists([x, y], "kron") + + # X = A + Bi, Y = C+Di + # kron(A, B) = kron(A, C) - kron(B, D) + (kron(A, D) + kron(B, C))i + (a, b) = (x.real, x.imag) if is_complex(x) else (x, None) + (c, d) = (y.real, y.imag) if is_complex(y) else (y, None) + + if is_real(b) and is_real(d): + real = math.kron(a, c) - math.kron(b, d) + imag = math.kron(a, d) + math.kron(b, c) + elif is_real(b): + real = math.kron(a, c) + imag = math.kron(b, c) + else: + # is_real(d) + real = math.kron(a, c) + imag = math.kron(a, d) + return ComplexVariable(real, imag) diff --git a/python/paddle/fluid/tests/unittests/test_complex_kron.py b/python/paddle/fluid/tests/unittests/test_complex_kron.py new file mode 100644 index 0000000000..50817a8734 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_complex_kron.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import fluid, tensor +import paddle.complex as cpx +import paddle.fluid.dygraph as dg +import numpy as np +import unittest + + +class ComplexKronTestCase(unittest.TestCase): + def __init__(self, methodName='runTest', x=None, y=None): + super(ComplexKronTestCase, self).__init__(methodName) + self.x = x + self.y = y + + def setUp(self): + self.ref_result = np.kron(self.x, self.y) + + def runTest(self): + place = fluid.CPUPlace() + self.test_identity(place) + + if fluid.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self.test_identity(place) + + def test_identity(self, place): + with dg.guard(place): + x_var = dg.to_variable(self.x) + y_var = dg.to_variable(self.y) + out_var = cpx.kron(x_var, y_var) + np.testing.assert_allclose(out_var.numpy(), self.ref_result) + + +def load_tests(loader, standard_tests, pattern): + suite = unittest.TestSuite() + suite.addTest( + ComplexKronTestCase( + x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2), + y=np.random.randn(3, 3) + 1j * np.random.randn(3, 3))) + suite.addTest( + ComplexKronTestCase( + x=np.random.randn(2, 2), + y=np.random.randn(3, 3) + 1j * np.random.randn(3, 3))) + suite.addTest( + ComplexKronTestCase( + x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2), + y=np.random.randn(3, 3))) + + suite.addTest( + ComplexKronTestCase( + x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2), + y=np.random.randn(2, 2, 3))) + return suite + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_kron_op.py b/python/paddle/fluid/tests/unittests/test_kron_op.py new file mode 100644 index 0000000000..57076b7551 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_kron_op.py @@ -0,0 +1,101 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dg + + +class TestKronOp(OpTest): + def setUp(self): + self.op_type = "kron" + self.dtype = self._init_dtype() + x = np.random.uniform(size=(10, 10)).astype(self.dtype) + y = np.random.uniform(size=(10, 10)).astype(self.dtype) + out_ref = np.kron(x, y) + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out_ref} + + def _init_dtype(self): + return "float64" + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'Y'], 'Out') + + +class TestKronOp2(TestKronOp): + def setUp(self): + self.op_type = "kron" + self.dtype = self._init_dtype() + x = np.random.uniform(size=(5, 5, 4)).astype(self.dtype) + y = np.random.uniform(size=(10, 10)).astype(self.dtype) + out_ref = np.kron(x, y) + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out_ref} + + +class TestKronOp3(TestKronOp): + def setUp(self): + self.op_type = "kron" + self.dtype = self._init_dtype() + x = np.random.uniform(size=(10, 10)).astype(self.dtype) + y = np.random.uniform(size=(5, 5, 4)).astype(self.dtype) + out_ref = np.kron(x, y) + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out_ref} + + +class TestKronLayer(unittest.TestCase): + def test_case(self): + a = np.random.randn(10, 10).astype(np.float64) + b = np.random.randn(10, 10).astype(np.float64) + + place = fluid.CPUPlace() + with dg.guard(place): + a_var = dg.to_variable(a) + b_var = dg.to_variable(b) + c_var = paddle.kron(a_var, b_var) + np.testing.assert_allclose(c_var.numpy(), np.kron(a, b)) + + def test_case_with_output(self): + a = np.random.randn(10, 10).astype(np.float64) + b = np.random.randn(10, 10).astype(np.float64) + + main = fluid.Program() + start = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, start): + a_var = fluid.data("a", [-1, -1], dtype="float64") + b_var = fluid.data("b", [-1, -1], dtype="float64") + out_var = fluid.layers.create_tensor("float64", "c") + paddle.kron(a_var, b_var, out=out_var) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(start) + c, = exe.run(main, feed={'a': a, 'b': b}, fetch_list=[out_var]) + np.testing.assert_allclose(c, np.kron(a, b)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 4ce7725b3a..d8f76a8e39 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -117,6 +117,7 @@ from .math import min #DEFINE_ALIAS from .math import mm #DEFINE_ALIAS from .math import div #DEFINE_ALIAS from .math import add #DEFINE_ALIAS +from .math import kron #DEFINE_ALIAS # from .math import atan #DEFINE_ALIAS from .math import logsumexp #DEFINE_ALIAS # from .math import inverse #DEFINE_ALIAS diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e99cac9a91..8936d79723 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -79,6 +79,7 @@ __all__ = [ 'addcmul', 'addmm', 'clamp', + 'kron', ] # yapf: enable. @@ -1412,3 +1413,64 @@ def clamp(input, min=None, max=None, output=None, name=None): type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs) return output + +@templatedoc(op_type="kron") +def kron(x, y, out=None, name=None): + """${comment} + + Args: + x (Variable): the fist operand of kron op, data type: float16, float32, + float64, int32 or int64. + y (Variable): the second operand of kron op, data type: float16, + float32, float64, int32 or int64. Its data type should be the same + with x. + out (Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of + operation. If out is None, a new Varibale will be create to store + the result. Defaults to None. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + Variable: The output of kron op, data type: float16, float32, float64, int32 or int64. Its data is the same with x. + + Examples: + .. code-block:: python + + import paddle + from paddle import fluid + import paddle.fluid.dygraph as dg + import numpy as np + + a = np.arange(1, 5).reshape(2, 2).astype(np.float32) + b = np.arange(1, 10).reshape(3, 3).astype(np.float32) + + place = fluid.CPUPlace() + with dg.guard(place): + a_var = dg.to_variable(a) + b_var = dg.to_variable(b) + c_var = paddle.kron(a_var, b_var) + c_np = c_var.numpy() + print(c_np) + + #[[ 1. 2. 3. 2. 4. 6.] + # [ 4. 5. 6. 8. 10. 12.] + # [ 7. 8. 9. 14. 16. 18.] + # [ 3. 6. 9. 4. 8. 12.] + # [12. 15. 18. 16. 20. 24.] + # [21. 24. 27. 28. 32. 36.]] + """ + if in_dygraph_mode(): + return core.ops.kron(x, y) + + helper = LayerHelper('kron', **locals()) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron') + check_variable_and_dtype(y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron') + + if out is None: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + check_variable_and_dtype(out, 'out', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron') + helper.append_op(type="kron", inputs={"X": x, "Y": y}, outputs={"Out": out}) + return out -- GitLab