From 7346edc21de3f8e0c87af0b963753c35f135f29f Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Tue, 12 Jan 2021 10:26:34 +0800 Subject: [PATCH] [Cherry-pick] Complex grad for matmul, kron and type promotion (#30304) * complex gradient matmul (#29966) * dot op support complex types * matmul support complex types * add test case * matmul broadcast gradient support complex * move conjFunctor to complex_functor.h * change the kron gradient when complex types (#29995) * type promotion for grad (#30177) * type promotion for grad * add type promotion for div op --- paddle/fluid/operators/conj_op.h | 40 +-- paddle/fluid/operators/dot_op.cc | 12 +- paddle/fluid/operators/dot_op.cu | 23 +- paddle/fluid/operators/dot_op.h | 256 ++++++++++++++---- .../elementwise/elementwise_div_op.h | 16 +- .../operators/elementwise/elementwise_op.h | 26 ++ paddle/fluid/operators/kron_op.cc | 13 + paddle/fluid/operators/kron_op.h | 125 +++++++++ .../fluid/operators/math/complex_functors.h | 37 +++ paddle/fluid/operators/matmul_v2_op.cc | 21 ++ paddle/fluid/operators/matmul_v2_op.h | 104 +++++-- .../fluid/tests/unittests/test_dot_op.py | 122 +++++++++ .../unittests/test_elementwise_div_op.py | 15 + .../unittests/test_elementwise_mul_op.py | 15 + .../unittests/test_elementwise_sub_op.py | 74 +++++ .../fluid/tests/unittests/test_kron_op.py | 99 +++++++ .../tests/unittests/test_matmul_v2_op.py | 136 ++++++++++ .../white_list/no_grad_set_white_list.py | 1 + 18 files changed, 1009 insertions(+), 126 deletions(-) diff --git a/paddle/fluid/operators/conj_op.h b/paddle/fluid/operators/conj_op.h index 0bec7b707e3..417a136c60b 100644 --- a/paddle/fluid/operators/conj_op.h +++ b/paddle/fluid/operators/conj_op.h @@ -17,49 +17,13 @@ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EnableComplex = - typename std::enable_if::value || - std::is_same::value>::type; - -template -using DisableComplex = typename std::enable_if< - !std::is_same::value && - !std::is_same::value>::type; - -template -struct ConjFunctor; - -template -struct ConjFunctor> { - ConjFunctor(const T* input, int64_t numel, T* output) - : input_(input), numel_(numel), output_(output) {} - - HOSTDEVICE void operator()(size_t idx) const { - output_[idx] = T(input_[idx].real, -input_[idx].imag); - } - const T* input_; - int64_t numel_; - T* output_; -}; - -template -struct ConjFunctor> { - ConjFunctor(const T* input, int64_t numel, T* output) - : input_(input), numel_(numel), output_(output) {} - - HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; } - const T* input_; - int64_t numel_; - T* output_; -}; - template class ConjKernel : public framework::OpKernel { public: @@ -74,7 +38,7 @@ class ConjKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); platform::ForRange for_range(dev_ctx, numel); - ConjFunctor functor(x_data, numel, out_data); + math::ConjFunctor functor(x_data, numel, out_data); for_range(functor); } }; diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc index 0527445adf0..26f12e8f9e3 100644 --- a/paddle/fluid/operators/dot_op.cc +++ b/paddle/fluid/operators/dot_op.cc @@ -152,9 +152,17 @@ REGISTER_OP_CPU_KERNEL( dot, ops::DotKernel, ops::DotKernel, ops::DotKernel, - ops::DotKernel); + 
ops::DotKernel, + ops::DotKernel, + ops::DotKernel); REGISTER_OP_CPU_KERNEL( dot_grad, ops::DotGradKernel, ops::DotGradKernel, ops::DotGradKernel, - ops::DotGradKernel); + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel); diff --git a/paddle/fluid/operators/dot_op.cu b/paddle/fluid/operators/dot_op.cu index eb7ebbe32d7..2d259ba1fbc 100644 --- a/paddle/fluid/operators/dot_op.cu +++ b/paddle/fluid/operators/dot_op.cu @@ -17,12 +17,17 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(dot, ops::DotKernel, - ops::DotKernel, - ops::DotKernel, - ops::DotKernel); -REGISTER_OP_CUDA_KERNEL(dot_grad, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel); +REGISTER_OP_CUDA_KERNEL( + dot, ops::DotKernel, + ops::DotKernel, + ops::DotKernel, + ops::DotKernel, + ops::DotKernel, + ops::DotKernel); +REGISTER_OP_CUDA_KERNEL( + dot_grad, ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel); diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index cec706300d7..c78ac87084c 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -16,95 +16,233 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using complex64 = platform::complex64; +using complex128 = platform::complex128; template using EigenMatrix = framework::EigenMatrix; +template +struct P { + void operator()(T a, R b); +}; + +template +struct DotGradFunction { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + const Tensor* tensor_dout, Tensor* tensor_dx, + Tensor* tensor_dy, + const paddle::framework::ExecutionContext& ctx); +}; + template -void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { +struct DotGradFunction> { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + const Tensor* tensor_dout, Tensor* tensor_dx, + Tensor* tensor_dy, + const paddle::framework::ExecutionContext& ctx) { #ifdef __NVCC__ - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); + if (1 == tensor_dout->dims().size()) { + auto dout = framework::EigenVector::Flatten(*tensor_dout); - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto& dev = *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - dx.device(dev) = y * dout.broadcast(size); - } + if (tensor_dx) { + auto y = framework::EigenVector::Flatten(*tensor_y); + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto& dev = *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - dy.device(dev) = x * dout.broadcast(size); + paddle::platform::ForRange for_range(dev_raw, + tensor_y->numel()); + math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), + tensor_dx->data()); + for_range(functor); + auto dx 
= framework::EigenVector::Flatten(*tensor_dx); + + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = framework::EigenVector::Flatten(*tensor_x); + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + + paddle::platform::ForRange for_range(dev_raw, + tensor_y->numel()); + math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), + tensor_dy->data()); + for_range(functor); + auto dy = framework::EigenVector::Flatten(*tensor_dy); + + dy.device(dev) = dy * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(ctx.GetPlace()); + auto y = EigenMatrix::From(*tensor_y); + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + + paddle::platform::ForRange for_range(dev_raw, + tensor_y->numel()); + math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), + tensor_dx->data()); + for_range(functor); + auto dx = EigenMatrix::From(*tensor_dx); + + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(ctx.GetPlace()); + auto x = EigenMatrix::From(*tensor_x); + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + + paddle::platform::ForRange for_range(dev_raw, + tensor_x->numel()); + math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), + tensor_dy->data()); + for_range(functor); + + auto dy = EigenMatrix::From(*tensor_dy); + + dy.device(dev) = dy * dout.broadcast(size); + } } - } else { - auto dout = EigenMatrix::From(*tensor_dout); +#else + const auto* data_dout = tensor_dout->data(); if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = EigenMatrix::From(*tensor_y); - auto dx = EigenMatrix::From(*tensor_dx); - auto& dev = *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - dx.device(dev) = y * dout.broadcast(size); + auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); + const auto* data_y = tensor_y->data(); + const framework::DDim& dim = tensor_x->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s]; + } } if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = EigenMatrix::From(*tensor_x); - auto dy = EigenMatrix::From(*tensor_dy); - auto& dev = *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - dy.device(dev) = x * dout.broadcast(size); + auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); + const auto* data_x = tensor_x->data(); + const framework::DDim& dim = tensor_y->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s]; + } } +#endif } +}; + +template +struct DotGradFunction> { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + const Tensor* tensor_dout, Tensor* tensor_dx, + Tensor* tensor_dy, + const paddle::framework::ExecutionContext& ctx) { +#ifdef __NVCC__ + if (1 == tensor_dout->dims().size()) { + auto dout = framework::EigenVector::Flatten(*tensor_dout); + + if (tensor_dx) { + auto 
y = framework::EigenVector::Flatten(*tensor_y); + auto dx = framework::EigenVector::Flatten(*tensor_dx); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = framework::EigenVector::Flatten(*tensor_x); + auto dy = framework::EigenVector::Flatten(*tensor_dy); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + dy.device(dev) = x * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(ctx.GetPlace()); + auto y = EigenMatrix::From(*tensor_y); + auto dx = EigenMatrix::From(*tensor_dx); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(ctx.GetPlace()); + auto x = EigenMatrix::From(*tensor_x); + auto dy = EigenMatrix::From(*tensor_dy); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + dy.device(dev) = x * dout.broadcast(size); + } + } #else - const auto* data_dout = tensor_dout->data(); + const auto* data_dout = tensor_dout->data(); - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_y = tensor_y->data(); - const framework::DDim& dim = tensor_x->dims(); - size_t N = static_cast(framework::product(dim)); + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); + const auto* data_y = tensor_y->data(); + const framework::DDim& dim = tensor_x->dims(); + size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; + auto step = dim[dim.size() - 1]; - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = data_y[i] * data_dout[s]; + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = data_y[i] * data_dout[s]; + } } - } - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_x = tensor_x->data(); - const framework::DDim& dim = tensor_y->dims(); - size_t N = static_cast(framework::product(dim)); + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); + const auto* data_x = tensor_x->data(); + const framework::DDim& dim = tensor_y->dims(); + size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; + auto step = dim[dim.size() - 1]; - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = data_x[i] * data_dout[s]; + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = data_x[i] * data_dout[s]; + } } - } #endif -} + } +}; template class DotKernel : public framework::OpKernel { @@ -165,8 +303,8 @@ class DotGradKernel : public framework::OpKernel { if (tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); - DotGradFunction(tensor_x, tensor_y, tensor_dout, - tensor_dx, tensor_dy, ctx); + DotGradFunction()(tensor_x, tensor_y, tensor_dout, + tensor_dx, tensor_dy, ctx); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index d824014713d..b6f6151e133 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -14,6 +14,7 @@ 
limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" @@ -203,7 +204,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out"); #ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx)) { @@ -214,6 +215,19 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { #endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + if (framework::IsComplexType(expected_kernel_type.data_type_)) { + // only promote inputs’s types when contains complex input + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } }; template diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index d799abf92d9..f426a54f794 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -288,6 +288,19 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { #endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.data_type_)) { + // only promote inputs’s types when contains complex input + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } }; class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { @@ -325,6 +338,19 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { #endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const { + if (framework::IsComplexType(expected_kernel_type.data_type_)) { + // only promote inputs’s types when contains complex input + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } }; class ElementwiseOpDoubleGradWithoutDXDY diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc index db25d05c6b2..dab9948edc3 100644 --- a/paddle/fluid/operators/kron_op.cc +++ b/paddle/fluid/operators/kron_op.cc @@ -134,6 +134,19 @@ class KronGradOp : public framework::OperatorWithKernel { OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name), ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& 
expected_kernel_type) const { + if (framework::IsComplexType(expected_kernel_type.data_type_)) { + // only promote inputs’s types when contains complex input + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } }; template diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h index 62762f3f049..2af3716ae43 100644 --- a/paddle/fluid/operators/kron_op.h +++ b/paddle/fluid/operators/kron_op.h @@ -26,6 +26,9 @@ limitations under the License. */ namespace paddle { namespace operators { +using complex64 = paddle::platform::complex64; +using complex128 = paddle::platform::complex128; + // Process an element in the output, used with a parallel-for template struct KronElemFunctor { @@ -172,6 +175,128 @@ struct KronGradElemFunctor { const int ndims_; }; +template <> +struct KronGradElemFunctor { + KronGradElemFunctor(const complex64* dout, const complex64* A, + const complex64* B, complex64* dout_a, complex64* dout_b, + const int64_t* stride_dout, const int64_t* stride_a, + const int64_t* stride_b, const int64_t* shape_b, + const int64_t numel_a, const int64_t numel_b, + const int ndims) + : dout_(dout), + A_(A), + B_(B), + dout_a_(dout_a), + dout_b_(dout_b), + stride_dout_(stride_dout), + stride_a_(stride_a), + stride_b_(stride_b), + shape_b_(shape_b), + numel_a_(numel_a), + numel_b_(numel_b), + ndims_(ndims) {} + + HOSTDEVICE void operator()(int64_t idx) { + int64_t index = idx; + int64_t index_a = 0; + int64_t index_b = 0; + for (int i = 0; i < ndims_; i++) { + auto pos_i = index / stride_dout_[i]; + index = index % stride_dout_[i]; + auto pos_ai = pos_i / shape_b_[i]; + auto pos_bi = pos_i % shape_b_[i]; + index_a += stride_a_[i] * pos_ai; + index_b += stride_b_[i] * pos_bi; + } + + if (dout_a_) { + size_t index_out_a = index_a * numel_b_ + index_b; + dout_a_[index_out_a] = + dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag); + } + if (dout_b_) { + size_t index_out_b = index_b * numel_a_ + index_a; + dout_b_[index_out_b] = + dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag); + } + } + + private: + const complex64* dout_; + const complex64* A_; + const complex64* B_; + complex64* dout_a_; + complex64* dout_b_; + const int64_t* stride_dout_; + const int64_t* stride_a_; + const int64_t* stride_b_; + const int64_t* shape_b_; + const int64_t numel_a_; + const int64_t numel_b_; + const int ndims_; +}; + +template <> +struct KronGradElemFunctor { + KronGradElemFunctor(const complex128* dout, const complex128* A, + const complex128* B, complex128* dout_a, + complex128* dout_b, const int64_t* stride_dout, + const int64_t* stride_a, const int64_t* stride_b, + const int64_t* shape_b, const int64_t numel_a, + const int64_t numel_b, const int ndims) + : dout_(dout), + A_(A), + B_(B), + dout_a_(dout_a), + dout_b_(dout_b), + stride_dout_(stride_dout), + stride_a_(stride_a), + stride_b_(stride_b), + shape_b_(shape_b), + numel_a_(numel_a), + numel_b_(numel_b), + ndims_(ndims) {} + + HOSTDEVICE void operator()(int64_t idx) { + int64_t index = idx; + int64_t index_a = 0; + int64_t index_b = 0; + for (int i = 0; i < ndims_; i++) { + auto pos_i = index / stride_dout_[i]; + index = index % stride_dout_[i]; + auto pos_ai = pos_i / shape_b_[i]; + auto pos_bi = pos_i % shape_b_[i]; + index_a += stride_a_[i] * pos_ai; + index_b += stride_b_[i] * pos_bi; + } + + if (dout_a_) { + size_t index_out_a = index_a * numel_b_ + 
index_b; + dout_a_[index_out_a] = + dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag); + } + if (dout_b_) { + size_t index_out_b = index_b * numel_a_ + index_a; + dout_b_[index_out_b] = + dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag); + } + } + + private: + const complex128* dout_; + const complex128* A_; + const complex128* B_; + complex128* dout_a_; + complex128* dout_b_; + const int64_t* stride_dout_; + const int64_t* stride_a_; + const int64_t* stride_b_; + const int64_t* shape_b_; + const int64_t numel_a_; + const int64_t numel_b_; + const int ndims_; +}; + template struct IdentityFunctor { HOSTDEVICE explicit inline IdentityFunctor() {} diff --git a/paddle/fluid/operators/math/complex_functors.h b/paddle/fluid/operators/math/complex_functors.h index 302e3d562c6..18a003d5c9a 100644 --- a/paddle/fluid/operators/math/complex_functors.h +++ b/paddle/fluid/operators/math/complex_functors.h @@ -135,6 +135,43 @@ struct ImagToComplexFunctor>> { int64_t numel_; }; +template +using EnableComplex = + typename std::enable_if::value || + std::is_same::value>::type; + +template +using DisableComplex = typename std::enable_if< + !std::is_same::value && + !std::is_same::value>::type; + +template +struct ConjFunctor; + +template +struct ConjFunctor> { + ConjFunctor(const T* input, int64_t numel, T* output) + : input_(input), numel_(numel), output_(output) {} + + HOSTDEVICE void operator()(size_t idx) const { + output_[idx] = T(input_[idx].real, -input_[idx].imag); + } + const T* input_; + int64_t numel_; + T* output_; +}; + +template +struct ConjFunctor> { + ConjFunctor(const T* input, int64_t numel, T* output) + : input_(input), numel_(numel), output_(output) {} + + HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; } + const T* input_; + int64_t numel_; + T* output_; +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 7a3db793184..6fccd3657af 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -150,6 +150,27 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { context->SetOutputDim(y_grad_name, y_dims); } } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto out_grad_name = framework::GradVarName("Out"); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name), + ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + if (framework::IsComplexType(expected_kernel_type.data_type_)) { + // only promote inputs’s types when contains complex input + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } }; template diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 129ea7c156d..f313fdbfbf7 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -22,6 +22,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #ifdef __NVCC__ @@ -439,6 +440,61 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, ReshapeTensorIntoMatrixSequence(y, mat_dim_y); } +template +struct ConjHelper { + explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} + HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { + dst.Resize(src.dims()); + dst.set_layout(src.layout()); + dst.ShareDataWith(src); + return; + } + + const framework::ExecutionContext& ctx_; +}; + +template +struct ConjHelper { + explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} + + HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { + dst.Resize(src.dims()); + auto* src_data = src.data(); + auto* dst_data = dst.mutable_data( + ctx_.GetPlace(), + size_t(src.numel() * sizeof(paddle::platform::complex64))); + + platform::ForRange for_range( + ctx_.template device_context(), src.numel()); + math::ConjFunctor functor( + src_data, src.numel(), dst_data); + for_range(functor); + return; + } + const framework::ExecutionContext& ctx_; +}; + +template +struct ConjHelper { + explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} + + HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { + dst.Resize(src.dims()); + auto* src_data = src.data(); + auto* dst_data = dst.mutable_data( + ctx_.GetPlace(), + size_t(src.numel() * sizeof(paddle::platform::complex128))); + + platform::ForRange for_range( + ctx_.template device_context(), src.numel()); + math::ConjFunctor functor( + src_data, src.numel(), dst_data); + for_range(functor); + return; + } + const framework::ExecutionContext& ctx_; +}; + template class MatMulV2GradKernel : public framework::OpKernel { public: @@ -490,6 +546,8 @@ class MatMulV2GradKernel : public framework::OpKernel { auto x = *ctx.Input("X"); auto y = *ctx.Input("Y"); auto dout = *ctx.Input(framework::GradVarName("Out")); + framework::Tensor y_conj(y.type()); + framework::Tensor x_conj(y.type()); // get dims std::vector x_dims = vectorize(x.dims()); @@ -508,7 +566,7 @@ class MatMulV2GradKernel : public framework::OpKernel { if (dx) dx->mutable_data(ctx.GetPlace()); if (dy) dy->mutable_data(ctx.GetPlace()); if (dout.numel() == 1) { - DotGradFunction(&x, &y, &dout, dx, dy, ctx); + DotGradFunction()(&x, &y, &dout, dx, dy, ctx); return; } } @@ -533,6 +591,10 @@ class MatMulV2GradKernel : public framework::OpKernel { if (dx_dims != x.dims()) { dx->Resize(x.dims()); } + + // for complex + ConjHelper conj_helper(ctx); + conj_helper(y, y_conj); } framework::DDim dy_dims; @@ -541,19 +603,23 @@ class MatMulV2GradKernel : public framework::OpKernel { if (dy_dims != y.dims()) { dy->Resize(y.dims()); } + + // for complex + ConjHelper conj_helper(ctx); + conj_helper(x, x_conj); } if (transpose_x && transpose_y) { - CalcInputGrad(ctx, y, true, true, dout, true, false, dx); - CalcInputGrad(ctx, dout, true, true, x, true, false, dy); + CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx); + CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy); } else if (transpose_x) { - CalcInputGrad(ctx, y, false, false, dout, true, false, dx); - CalcInputGrad(ctx, x, false, false, dout, false, true, dy); + CalcInputGrad(ctx, y_conj, false, false, dout, true, false, 
dx); + CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy); } else if (transpose_y) { - CalcInputGrad(ctx, dout, false, false, y, false, true, dx); - CalcInputGrad(ctx, dout, true, true, x, false, true, dy); + CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx); + CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy); } else { - CalcInputGrad(ctx, dout, false, false, y, true, false, dx); - CalcInputGrad(ctx, x, true, true, dout, false, true, dy); + CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx); + CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy); } if (dx) { @@ -573,40 +639,44 @@ class MatMulV2GradKernel : public framework::OpKernel { VLOG(3) << "It need cost much time to reduce sum for the broadcast and " "wastes the memory. So we should avoid the case in reality"; Tensor dx_help, dy_help; + + ConjHelper conj_helper(ctx); + conj_helper(x, x_conj); + conj_helper(y, y_conj); if (transpose_x) { if (transpose_y) { // X'Y': dA = Y'G', dB = G'X' if (dx) - MatMulFunction(&y, &dout, y_dims, dout_dims, + MatMulFunction(&y_conj, &dout, y_dims, dout_dims, &dx_help, true, true, ctx); if (dy) - MatMulFunction(&dout, &x, dout_dims, x_dims, + MatMulFunction(&dout, &x_conj, dout_dims, x_dims, &dy_help, true, true, ctx); } else { // X'Y: dX = YG', dY = XG if (dx) - MatMulFunction(&y, &dout, y_dims, dout_dims, + MatMulFunction(&y_conj, &dout, y_dims, dout_dims, &dx_help, false, true, ctx); if (dy) - MatMulFunction(&x, &dout, x_dims, dout_dims, + MatMulFunction(&x_conj, &dout, x_dims, dout_dims, &dy_help, false, false, ctx); } } else { if (transpose_y) { // XY': dX = GY, dY = G'X if (dx) - MatMulFunction(&dout, &y, dout_dims, y_dims, + MatMulFunction(&dout, &y_conj, dout_dims, y_dims, &dx_help, false, false, ctx); if (dy) - MatMulFunction(&dout, &x, dout_dims, x_dims, + MatMulFunction(&dout, &x_conj, dout_dims, x_dims, &dy_help, true, false, ctx); } else { // XY: dX = GY', dY = X'G if (dx) - MatMulFunction(&dout, &y, dout_dims, y_dims, + MatMulFunction(&dout, &y_conj, dout_dims, y_dims, &dx_help, false, true, ctx); if (dy) - MatMulFunction(&x, &dout, x_dims, dout_dims, + MatMulFunction(&x_conj, &dout, x_dims, dout_dims, &dy_help, true, false, ctx); } } diff --git a/python/paddle/fluid/tests/unittests/test_dot_op.py b/python/paddle/fluid/tests/unittests/test_dot_op.py index d95f818a62b..f65301f2d86 100644 --- a/python/paddle/fluid/tests/unittests/test_dot_op.py +++ b/python/paddle/fluid/tests/unittests/test_dot_op.py @@ -101,5 +101,127 @@ class TestDygraph(unittest.TestCase): paddle.dot(x1, y1).numpy(), np.array([[17], [58]]))) +class TestComplexDotOp(OpTest): + def setUp(self): + self.op_type = "dot" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random(100).astype( + self.dtype) + 1J * np.random.random(100).astype(self.dtype) + self.y = np.random.random(100).astype( + self.dtype) + 1J * np.random.random(100).astype(self.dtype) + self.out = np.dot(self.x, self.y) + + def init_grad_input_output(self): + self.grad_out = np.ones(1, self.dtype) + 1J * np.ones(1, self.dtype) + self.grad_x = self.grad_out * np.conj(self.y) + self.grad_y = self.grad_out * np.conj(self.x) + + def test_check_output(self): + 
self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +class TestComplexDotOp2D(OpTest): + def setUp(self): + self.op_type = "dot" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random( + (2, 100)).astype(self.dtype) + 1J * np.random.random( + (2, 100)).astype(self.dtype) + self.y = np.random.random( + (2, 100)).astype(self.dtype) + 1J * np.random.random( + (2, 100)).astype(self.dtype) + self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1) + + def init_grad_input_output(self): + self.grad_out = np.ones((2, 1), self.dtype) + 1J * np.ones( + (2, 1), self.dtype) + self.grad_x = self._get_grad(self.grad_out, self.y) + self.grad_y = self._get_grad(self.grad_out, self.x) + + def _get_grad(self, grad_out, input): + grad = np.empty((0, input.shape[1])) + for i in range(grad_out.shape[0]): + grad = np.append(grad, [grad_out[i] * np.conj(input[i])], axis=0) + return grad + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py index f93802c47c9..32860a6694a 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py @@ -320,6 +320,21 @@ class TestComplexElementwiseDivOp(OpTest): user_defined_grad_outputs=[self.grad_out]) +class TestRealComplexElementwiseDivOp(TestComplexElementwiseDivOp): + def init_input_output(self): + self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype) + self.y = np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + self.out = self.x / self.y + + def init_grad_input_output(self): + self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones( + (2, 3, 4, 5), self.dtype) + self.grad_x = np.real(self.grad_out / np.conj(self.y)) + self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py 
b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index f69fa7084ed..7bace9bc535 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -304,6 +304,21 @@ class TestComplexElementwiseMulOp(OpTest): user_defined_grad_outputs=[self.grad_out]) +class TestRealComplexElementwiseMulOp(TestComplexElementwiseMulOp): + def init_input_output(self): + self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype) + self.y = np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + self.out = self.x * self.y + + def init_grad_input_output(self): + self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones( + (2, 3, 4, 5), self.dtype) + self.grad_x = np.real(self.grad_out * np.conj(self.y)) + self.grad_y = self.grad_out * np.conj(self.x) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py index 6434807c551..c5372d5b758 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py @@ -15,6 +15,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle from op_test import OpTest, skip_check_grad_ci @@ -164,5 +165,78 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp): } +class TestComplexElementwiseSubOp(OpTest): + def setUp(self): + self.op_type = "elementwise_sub" + self.dtype = np.float64 + self.shape = (2, 3, 4, 5) + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random(self.shape).astype( + self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype) + self.y = np.random.random(self.shape).astype( + self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype) + self.out = self.x - self.y + + def init_grad_input_output(self): + self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones( + self.shape, self.dtype) + self.grad_x = self.grad_out + self.grad_y = -self.grad_out + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +class TestRealComplexElementwiseSubOp(TestComplexElementwiseSubOp): + def init_input_output(self): + self.x = np.random.random(self.shape).astype(self.dtype) + self.y = np.random.random(self.shape).astype( + self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype) + self.out = self.x - self.y + + def init_grad_input_output(self): + self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones( + self.shape, self.dtype) + self.grad_x = 
np.real(self.grad_out) + self.grad_y = -self.grad_out + + if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_kron_op.py b/python/paddle/fluid/tests/unittests/test_kron_op.py index 68ad35489ce..d6db4c2f074 100644 --- a/python/paddle/fluid/tests/unittests/test_kron_op.py +++ b/python/paddle/fluid/tests/unittests/test_kron_op.py @@ -102,5 +102,104 @@ class TestKronLayer(unittest.TestCase): np.testing.assert_allclose(c, np.kron(a, b)) +class TestComplexKronOp(OpTest): + def setUp(self): + self.op_type = "kron" + self.x_shape = np.array([10, 10]) + self.y_shape = np.array([3, 35]) + self.out_shape = self.x_shape * self.y_shape + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random(self.x_shape).astype( + self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype) + self.y = np.random.random(self.y_shape).astype( + self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype) + self.out = np.kron(self.x, self.y) + + def init_grad_input_output(self): + self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones( + self.out_shape, self.dtype) + self.grad_x = self.get_grad_x_by_numpy() + self.grad_y = self.get_grad_y_by_numpy() + + def get_grad_x_by_numpy(self): + grad_x = np.zeros(self.x_shape, np.complex) + for x_i in range(self.x_shape[0]): + for x_j in range(self.x_shape[1]): + for i in range(self.y_shape[0]): + for j in range(self.y_shape[1]): + idx_i = x_i * self.y_shape[0] + i + idx_j = x_j * self.y_shape[1] + j + grad_x[x_i][x_j] += self.grad_out[idx_i][ + idx_j] * np.conj(self.y[i][j]) + return grad_x + + def get_grad_y_by_numpy(self): + grad_y = np.zeros(self.y_shape, np.complex) + for y_i in range(self.y_shape[0]): + for y_j in range(self.y_shape[1]): + for x_i in range(self.x_shape[0]): + for x_j in range(self.x_shape[1]): + idx_i = x_i * self.y_shape[0] + y_i + idx_j = x_j * self.y_shape[1] + y_j + grad_y[y_i][y_j] += self.grad_out[idx_i][ + idx_j] * np.conj(self.x[x_i][x_j]) + return grad_y + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +class TestKronOpTypePromotion(TestComplexKronOp): + def init_input_output(self): + self.x = np.random.random(self.x_shape).astype(self.dtype) + self.y = np.random.random(self.y_shape).astype( + self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype) + self.out = np.kron(self.x, self.y) + + def init_grad_input_output(self): + self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones( + self.out_shape, self.dtype) + self.grad_x = self.get_grad_x_by_numpy().real + self.grad_y = self.get_grad_y_by_numpy() + + if __name__ == '__main__': + paddle.enable_static() 
unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index c137a4a9feb..7b4f16cec01 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -405,5 +405,141 @@ class TestMatMulV2API(unittest.TestCase): result = paddle.matmul(x, y) +class TestComplexMatMulOp(OpTest): + def setUp(self): + self.op_type = "matmul_v2" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random( + (10, 10)).astype(self.dtype) + 1J * np.random.random( + (10, 10)).astype(self.dtype) + self.y = np.random.random( + (10, 10)).astype(self.dtype) + 1J * np.random.random( + (10, 10)).astype(self.dtype) + self.out = np.dot(self.x, self.y) + + def init_grad_input_output(self): + self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones( + (10, 10), self.dtype) + self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T) + self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +class TestComplexMatMulOpBroadcast(OpTest): + def setUp(self): + self.op_type = "matmul_v2" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random( + (10, 2, 5)).astype(self.dtype) + 1J * np.random.random( + (10, 2, 5)).astype(self.dtype) + self.y = np.random.random( + (5, 20)).astype(self.dtype) + 1J * np.random.random( + (5, 20)).astype(self.dtype) + self.out = np.dot(self.x, self.y) + + def init_grad_input_output(self): + self.grad_out = np.ones((10, 2, 20), self.dtype) + 1J * np.ones( + (10, 2, 20), self.dtype) + self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T) + self.grad_y = np.sum(np.matmul( + np.conj(self.x).transpose(0, 2, 1), self.grad_out), + axis=0) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + 
user_defined_grad_outputs=[self.grad_out])
+
+
+class TestMatMulTypePromotion(TestComplexMatMulOp):
+    def init_input_output(self):
+        self.x = np.random.random((10, 10)).astype(self.dtype)
+        self.y = np.random.random(
+            (10, 10)).astype(self.dtype) + 1J * np.random.random(
+                (10, 10)).astype(self.dtype)
+        self.out = np.dot(self.x, self.y)
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones(
+            (10, 10), self.dtype)
+        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T).real
+        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)
+
+
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
index 330cf5a72b1..15ba331e9de 100644
--- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
@@ -59,6 +59,7 @@ NEED_TO_FIX_OP_LIST = [
     'lstmp',
     'margin_rank_loss',
     'matmul',
+    'matmul_v2',
     'mul',
     'multiplex',
     'rank_loss',
-- 
GitLab
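
Note on the complex gradient convention these kernels implement: for Out = X @ Y with complex operands, the backward pass multiplies the incoming gradient by the conjugated other operand, i.e. dX = dOut @ conj(Y)^T and dY = conj(X)^T @ dOut; the dot and kron gradients apply the same conjugate to their non-differentiated operand elementwise. The following is an illustrative numpy sketch, not code from this patch: it reproduces the values the new TestComplexMatMulOp passes as user_defined_grads and checks them against finite differences; the loss function, shapes, and variable names are assumptions made only for the demonstration.

import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 4)) + 1j * rng.random((4, 4))
y = rng.random((4, 4)) + 1j * rng.random((4, 4))
grad_out = np.ones((4, 4)) + 1j * np.ones((4, 4))  # same dOut the new tests feed in

# Conjugate rule implemented by the patched backward kernels (asserted by
# TestComplexMatMulOp through user_defined_grads):
grad_x = grad_out @ np.conj(y).T
grad_y = np.conj(x).T @ grad_out

# Reading grad_out as dL/dRe(Out) + 1j*dL/dIm(Out) for the real-valued loss
# L = Re(sum(conj(grad_out) * (X @ Y))), the rule above must agree with finite
# differences taken separately along the real and imaginary axes of one entry of X.
def loss(x_):
    return np.real(np.sum(np.conj(grad_out) * (x_ @ y)))

eps, i, j = 1e-6, 1, 2
d_re = np.zeros_like(x)
d_re[i, j] = eps
d_im = np.zeros_like(x)
d_im[i, j] = 1j * eps
assert np.isclose((loss(x + d_re) - loss(x)) / eps, grad_x[i, j].real)
assert np.isclose((loss(x + d_im) - loss(x)) / eps, grad_x[i, j].imag)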
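Note on the type-promotion half of the patch: the GetKernelTypeForVar overrides added to the elementwise, kron, and matmul_v2 grad ops keep each input tensor at its own dtype when the expected kernel type is complex, so a real input paired with a complex input is promoted inside the kernel and its gradient comes back as the real part of the complex result. A minimal numpy sketch of the expected values follows; it mirrors what TestMatMulTypePromotion asserts rather than Paddle internals, and the shapes are illustrative.

import numpy as np

rng = np.random.default_rng(1)
x = rng.random((10, 10))                               # real input
y = rng.random((10, 10)) + 1j * rng.random((10, 10))   # complex input
grad_out = np.ones((10, 10)) + 1j * np.ones((10, 10))  # complex upstream gradient

# The backward computation runs in complex arithmetic after promotion; the
# gradient returned for the real input keeps only its real part, so it stays
# the same dtype as x, while y's gradient remains complex.
grad_x = np.matmul(grad_out, np.conj(y).T).real   # float64, like x
grad_y = np.matmul(np.conj(x).T, grad_out)        # complex128, like y

assert grad_x.dtype == x.dtype
assert grad_y.dtype == y.dtype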