From e012930aa375fc412617cf94ff05ef454e87e99e Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 31 Dec 2020 14:39:57 +0800 Subject: [PATCH] complex gradient matmul (#29966) * dot op support complex types * matmul support complex types * add test case * matmul broadcast gradient support complex * move conjFunctor to complex_functor.h --- paddle/fluid/operators/conj_op.h | 40 +-- paddle/fluid/operators/dot_op.cc | 12 +- paddle/fluid/operators/dot_op.cu | 23 +- paddle/fluid/operators/dot_op.h | 256 ++++++++++++++---- .../fluid/operators/math/complex_functors.h | 37 +++ paddle/fluid/operators/matmul_v2_op.h | 104 +++++-- .../fluid/tests/unittests/test_dot_op.py | 122 +++++++++ .../tests/unittests/test_matmul_v2_op.py | 121 +++++++++ .../white_list/no_grad_set_white_list.py | 1 + 9 files changed, 591 insertions(+), 125 deletions(-) diff --git a/paddle/fluid/operators/conj_op.h b/paddle/fluid/operators/conj_op.h index 0bec7b707e3..417a136c60b 100644 --- a/paddle/fluid/operators/conj_op.h +++ b/paddle/fluid/operators/conj_op.h @@ -17,49 +17,13 @@ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EnableComplex = - typename std::enable_if::value || - std::is_same::value>::type; - -template -using DisableComplex = typename std::enable_if< - !std::is_same::value && - !std::is_same::value>::type; - -template -struct ConjFunctor; - -template -struct ConjFunctor> { - ConjFunctor(const T* input, int64_t numel, T* output) - : input_(input), numel_(numel), output_(output) {} - - HOSTDEVICE void operator()(size_t idx) const { - output_[idx] = T(input_[idx].real, -input_[idx].imag); - } - const T* input_; - int64_t numel_; - T* output_; -}; - -template -struct ConjFunctor> { - ConjFunctor(const T* input, int64_t numel, T* output) - : input_(input), numel_(numel), output_(output) {} - - HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; } - const T* input_; - int64_t numel_; - T* output_; -}; - template class ConjKernel : public framework::OpKernel { public: @@ -74,7 +38,7 @@ class ConjKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); platform::ForRange for_range(dev_ctx, numel); - ConjFunctor functor(x_data, numel, out_data); + math::ConjFunctor functor(x_data, numel, out_data); for_range(functor); } }; diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc index 0527445adf0..26f12e8f9e3 100644 --- a/paddle/fluid/operators/dot_op.cc +++ b/paddle/fluid/operators/dot_op.cc @@ -152,9 +152,17 @@ REGISTER_OP_CPU_KERNEL( dot, ops::DotKernel, ops::DotKernel, ops::DotKernel, - ops::DotKernel); + ops::DotKernel, + ops::DotKernel, + ops::DotKernel); REGISTER_OP_CPU_KERNEL( dot_grad, ops::DotGradKernel, ops::DotGradKernel, ops::DotGradKernel, - ops::DotGradKernel); + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel); diff --git a/paddle/fluid/operators/dot_op.cu b/paddle/fluid/operators/dot_op.cu index eb7ebbe32d7..2d259ba1fbc 100644 --- a/paddle/fluid/operators/dot_op.cu +++ b/paddle/fluid/operators/dot_op.cu @@ -17,12 +17,17 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(dot, ops::DotKernel, - ops::DotKernel, - ops::DotKernel, - ops::DotKernel); -REGISTER_OP_CUDA_KERNEL(dot_grad, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel); +REGISTER_OP_CUDA_KERNEL( + dot, ops::DotKernel, + ops::DotKernel, + ops::DotKernel, + ops::DotKernel, + ops::DotKernel, + ops::DotKernel); +REGISTER_OP_CUDA_KERNEL( + dot_grad, ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel); diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index cec706300d7..c78ac87084c 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -16,95 +16,233 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using complex64 = platform::complex64; +using complex128 = platform::complex128; template using EigenMatrix = framework::EigenMatrix; +template +struct P { + void operator()(T a, R b); +}; + +template +struct DotGradFunction { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + const Tensor* tensor_dout, Tensor* tensor_dx, + Tensor* tensor_dy, + const paddle::framework::ExecutionContext& ctx); +}; + template -void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { +struct DotGradFunction> { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + const Tensor* tensor_dout, Tensor* tensor_dx, + Tensor* tensor_dy, + const paddle::framework::ExecutionContext& ctx) { #ifdef __NVCC__ - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); + if (1 == tensor_dout->dims().size()) { + auto dout = framework::EigenVector::Flatten(*tensor_dout); - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto& dev = *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - dx.device(dev) = y * dout.broadcast(size); - } + if (tensor_dx) { + auto y = framework::EigenVector::Flatten(*tensor_y); + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto& dev = *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - dy.device(dev) = x * dout.broadcast(size); + paddle::platform::ForRange for_range(dev_raw, + tensor_y->numel()); + math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), + tensor_dx->data()); + for_range(functor); + auto dx = framework::EigenVector::Flatten(*tensor_dx); + + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = framework::EigenVector::Flatten(*tensor_x); + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + + paddle::platform::ForRange for_range(dev_raw, + tensor_y->numel()); + math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), + tensor_dy->data()); + for_range(functor); + auto dy = framework::EigenVector::Flatten(*tensor_dy); + + dy.device(dev) = dy * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(ctx.GetPlace()); + auto y = EigenMatrix::From(*tensor_y); + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + + paddle::platform::ForRange for_range(dev_raw, + tensor_y->numel()); + math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), + tensor_dx->data()); + for_range(functor); + auto dx = EigenMatrix::From(*tensor_dx); + + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(ctx.GetPlace()); + auto x = EigenMatrix::From(*tensor_x); + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + + paddle::platform::ForRange for_range(dev_raw, + tensor_x->numel()); + math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), + tensor_dy->data()); + for_range(functor); + + auto dy = EigenMatrix::From(*tensor_dy); + + dy.device(dev) = dy * dout.broadcast(size); + } } - } else { - auto dout = EigenMatrix::From(*tensor_dout); +#else + const auto* data_dout = tensor_dout->data(); if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = EigenMatrix::From(*tensor_y); - auto dx = EigenMatrix::From(*tensor_dx); - auto& dev = *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - dx.device(dev) = y * dout.broadcast(size); + auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); + const auto* data_y = tensor_y->data(); + const framework::DDim& dim = tensor_x->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s]; + } } if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = EigenMatrix::From(*tensor_x); - auto dy = EigenMatrix::From(*tensor_dy); - auto& dev = *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - dy.device(dev) = x * dout.broadcast(size); + auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); + const auto* data_x = tensor_x->data(); + const framework::DDim& dim = tensor_y->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s]; + } } +#endif } +}; + +template +struct DotGradFunction> { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + const Tensor* tensor_dout, Tensor* tensor_dx, + Tensor* tensor_dy, + const paddle::framework::ExecutionContext& ctx) { +#ifdef __NVCC__ + if (1 == tensor_dout->dims().size()) { + auto dout = framework::EigenVector::Flatten(*tensor_dout); + + if (tensor_dx) { + auto y = framework::EigenVector::Flatten(*tensor_y); + auto dx = framework::EigenVector::Flatten(*tensor_dx); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = framework::EigenVector::Flatten(*tensor_x); + auto dy = framework::EigenVector::Flatten(*tensor_dy); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + dy.device(dev) = x * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(ctx.GetPlace()); + auto y = EigenMatrix::From(*tensor_y); + auto dx = EigenMatrix::From(*tensor_dx); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(ctx.GetPlace()); + auto x = EigenMatrix::From(*tensor_x); + auto dy = EigenMatrix::From(*tensor_dy); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + dy.device(dev) = x * dout.broadcast(size); + } + } #else - const auto* data_dout = tensor_dout->data(); + const auto* data_dout = tensor_dout->data(); - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_y = tensor_y->data(); - const framework::DDim& dim = tensor_x->dims(); - size_t N = static_cast(framework::product(dim)); + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); + const auto* data_y = tensor_y->data(); + const framework::DDim& dim = tensor_x->dims(); + size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; + auto step = dim[dim.size() - 1]; - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = data_y[i] * data_dout[s]; + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = data_y[i] * data_dout[s]; + } } - } - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_x = tensor_x->data(); - const framework::DDim& dim = tensor_y->dims(); - size_t N = static_cast(framework::product(dim)); + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); + const auto* data_x = tensor_x->data(); + const framework::DDim& dim = tensor_y->dims(); + size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; + auto step = dim[dim.size() - 1]; - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = data_x[i] * data_dout[s]; + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = data_x[i] * data_dout[s]; + } } - } #endif -} + } +}; template class DotKernel : public framework::OpKernel { @@ -165,8 +303,8 @@ class DotGradKernel : public framework::OpKernel { if (tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); - DotGradFunction(tensor_x, tensor_y, tensor_dout, - tensor_dx, tensor_dy, ctx); + DotGradFunction()(tensor_x, tensor_y, tensor_dout, + tensor_dx, tensor_dy, ctx); } }; diff --git a/paddle/fluid/operators/math/complex_functors.h b/paddle/fluid/operators/math/complex_functors.h index 302e3d562c6..18a003d5c9a 100644 --- a/paddle/fluid/operators/math/complex_functors.h +++ b/paddle/fluid/operators/math/complex_functors.h @@ -135,6 +135,43 @@ struct ImagToComplexFunctor>> { int64_t numel_; }; +template +using EnableComplex = + typename std::enable_if::value || + std::is_same::value>::type; + +template +using DisableComplex = typename std::enable_if< + !std::is_same::value && + !std::is_same::value>::type; + +template +struct ConjFunctor; + +template +struct ConjFunctor> { + ConjFunctor(const T* input, int64_t numel, T* output) + : input_(input), numel_(numel), output_(output) {} + + HOSTDEVICE void operator()(size_t idx) const { + output_[idx] = T(input_[idx].real, -input_[idx].imag); + } + const T* input_; + int64_t numel_; + T* output_; +}; + +template +struct ConjFunctor> { + ConjFunctor(const T* input, int64_t numel, T* output) + : input_(input), numel_(numel), output_(output) {} + + HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; } + const T* input_; + int64_t numel_; + T* output_; +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 8a83a29d484..b6eac7bf0cc 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #ifdef __NVCC__ @@ -468,6 +469,61 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, ReshapeTensorIntoMatrixSequence(y, mat_dim_y); } +template +struct ConjHelper { + explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} + HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { + dst.Resize(src.dims()); + dst.set_layout(src.layout()); + dst.ShareDataWith(src); + return; + } + + const framework::ExecutionContext& ctx_; +}; + +template +struct ConjHelper { + explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} + + HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { + dst.Resize(src.dims()); + auto* src_data = src.data(); + auto* dst_data = dst.mutable_data( + ctx_.GetPlace(), + size_t(src.numel() * sizeof(paddle::platform::complex64))); + + platform::ForRange for_range( + ctx_.template device_context(), src.numel()); + math::ConjFunctor functor( + src_data, src.numel(), dst_data); + for_range(functor); + return; + } + const framework::ExecutionContext& ctx_; +}; + +template +struct ConjHelper { + explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} + + HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { + dst.Resize(src.dims()); + auto* src_data = src.data(); + auto* dst_data = dst.mutable_data( + ctx_.GetPlace(), + size_t(src.numel() * sizeof(paddle::platform::complex128))); + + platform::ForRange for_range( + ctx_.template device_context(), src.numel()); + math::ConjFunctor functor( + src_data, src.numel(), dst_data); + for_range(functor); + return; + } + const framework::ExecutionContext& ctx_; +}; + template class MatMulV2GradKernel : public framework::OpKernel { public: @@ -519,6 +575,8 @@ class MatMulV2GradKernel : public framework::OpKernel { auto x = *ctx.Input("X"); auto y = *ctx.Input("Y"); auto dout = *ctx.Input(framework::GradVarName("Out")); + framework::Tensor y_conj(y.type()); + framework::Tensor x_conj(y.type()); // get dims std::vector x_dims = vectorize(x.dims()); @@ -537,7 +595,7 @@ class MatMulV2GradKernel : public framework::OpKernel { if (dx) dx->mutable_data(ctx.GetPlace()); if (dy) dy->mutable_data(ctx.GetPlace()); if (dout.numel() == 1) { - DotGradFunction(&x, &y, &dout, dx, dy, ctx); + DotGradFunction()(&x, &y, &dout, dx, dy, ctx); return; } } @@ -562,6 +620,10 @@ class MatMulV2GradKernel : public framework::OpKernel { if (dx_dims != x.dims()) { dx->Resize(x.dims()); } + + // for complex + ConjHelper conj_helper(ctx); + conj_helper(y, y_conj); } framework::DDim dy_dims; @@ -570,19 +632,23 @@ class MatMulV2GradKernel : public framework::OpKernel { if (dy_dims != y.dims()) { dy->Resize(y.dims()); } + + // for complex + ConjHelper conj_helper(ctx); + conj_helper(x, x_conj); } if (transpose_x && transpose_y) { - CalcInputGrad(ctx, y, true, true, dout, true, false, dx); - CalcInputGrad(ctx, dout, true, true, x, true, false, dy); + CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx); + CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy); } else if (transpose_x) { - CalcInputGrad(ctx, y, false, false, dout, true, false, dx); - CalcInputGrad(ctx, x, false, false, dout, false, true, dy); + CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx); + CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy); } else if (transpose_y) { - CalcInputGrad(ctx, dout, false, false, y, false, true, dx); - CalcInputGrad(ctx, dout, true, true, x, false, true, dy); + CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx); + CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy); } else { - CalcInputGrad(ctx, dout, false, false, y, true, false, dx); - CalcInputGrad(ctx, x, true, true, dout, false, true, dy); + CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx); + CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy); } if (dx) { @@ -602,40 +668,44 @@ class MatMulV2GradKernel : public framework::OpKernel { VLOG(3) << "It need cost much time to reduce sum for the broadcast and " "wastes the memory. So we should avoid the case in reality"; Tensor dx_help, dy_help; + + ConjHelper conj_helper(ctx); + conj_helper(x, x_conj); + conj_helper(y, y_conj); if (transpose_x) { if (transpose_y) { // X'Y': dA = Y'G', dB = G'X' if (dx) - MatMulFunction(&y, &dout, y_dims, dout_dims, + MatMulFunction(&y_conj, &dout, y_dims, dout_dims, &dx_help, true, true, ctx); if (dy) - MatMulFunction(&dout, &x, dout_dims, x_dims, + MatMulFunction(&dout, &x_conj, dout_dims, x_dims, &dy_help, true, true, ctx); } else { // X'Y: dX = YG', dY = XG if (dx) - MatMulFunction(&y, &dout, y_dims, dout_dims, + MatMulFunction(&y_conj, &dout, y_dims, dout_dims, &dx_help, false, true, ctx); if (dy) - MatMulFunction(&x, &dout, x_dims, dout_dims, + MatMulFunction(&x_conj, &dout, x_dims, dout_dims, &dy_help, false, false, ctx); } } else { if (transpose_y) { // XY': dX = GY, dY = G'X if (dx) - MatMulFunction(&dout, &y, dout_dims, y_dims, + MatMulFunction(&dout, &y_conj, dout_dims, y_dims, &dx_help, false, false, ctx); if (dy) - MatMulFunction(&dout, &x, dout_dims, x_dims, + MatMulFunction(&dout, &x_conj, dout_dims, x_dims, &dy_help, true, false, ctx); } else { // XY: dX = GY', dY = X'G if (dx) - MatMulFunction(&dout, &y, dout_dims, y_dims, + MatMulFunction(&dout, &y_conj, dout_dims, y_dims, &dx_help, false, true, ctx); if (dy) - MatMulFunction(&x, &dout, x_dims, dout_dims, + MatMulFunction(&x_conj, &dout, x_dims, dout_dims, &dy_help, true, false, ctx); } } diff --git a/python/paddle/fluid/tests/unittests/test_dot_op.py b/python/paddle/fluid/tests/unittests/test_dot_op.py index d95f818a62b..f65301f2d86 100644 --- a/python/paddle/fluid/tests/unittests/test_dot_op.py +++ b/python/paddle/fluid/tests/unittests/test_dot_op.py @@ -101,5 +101,127 @@ class TestDygraph(unittest.TestCase): paddle.dot(x1, y1).numpy(), np.array([[17], [58]]))) +class TestComplexDotOp(OpTest): + def setUp(self): + self.op_type = "dot" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random(100).astype( + self.dtype) + 1J * np.random.random(100).astype(self.dtype) + self.y = np.random.random(100).astype( + self.dtype) + 1J * np.random.random(100).astype(self.dtype) + self.out = np.dot(self.x, self.y) + + def init_grad_input_output(self): + self.grad_out = np.ones(1, self.dtype) + 1J * np.ones(1, self.dtype) + self.grad_x = self.grad_out * np.conj(self.y) + self.grad_y = self.grad_out * np.conj(self.x) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +class TestComplexDotOp2D(OpTest): + def setUp(self): + self.op_type = "dot" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random( + (2, 100)).astype(self.dtype) + 1J * np.random.random( + (2, 100)).astype(self.dtype) + self.y = np.random.random( + (2, 100)).astype(self.dtype) + 1J * np.random.random( + (2, 100)).astype(self.dtype) + self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1) + + def init_grad_input_output(self): + self.grad_out = np.ones((2, 1), self.dtype) + 1J * np.ones( + (2, 1), self.dtype) + self.grad_x = self._get_grad(self.grad_out, self.y) + self.grad_y = self._get_grad(self.grad_out, self.x) + + def _get_grad(self, grad_out, input): + grad = np.empty((0, input.shape[1])) + for i in range(grad_out.shape[0]): + grad = np.append(grad, [grad_out[i] * np.conj(input[i])], axis=0) + return grad + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index 76172632c71..f944f84c6c1 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -405,5 +405,126 @@ class TestMatMulV2API(unittest.TestCase): result = paddle.matmul(x, y) +class TestComplexMatMulOp(OpTest): + def setUp(self): + self.op_type = "matmul_v2" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random( + (10, 10)).astype(self.dtype) + 1J * np.random.random( + (10, 10)).astype(self.dtype) + self.y = np.random.random( + (10, 10)).astype(self.dtype) + 1J * np.random.random( + (10, 10)).astype(self.dtype) + self.out = np.dot(self.x, self.y) + + def init_grad_input_output(self): + self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones( + (10, 10), self.dtype) + self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T) + self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +class TestComplexMatMulOpBroadcast(OpTest): + def setUp(self): + self.op_type = "matmul_v2" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random( + (10, 2, 5)).astype(self.dtype) + 1J * np.random.random( + (10, 2, 5)).astype(self.dtype) + self.y = np.random.random( + (5, 20)).astype(self.dtype) + 1J * np.random.random( + (5, 20)).astype(self.dtype) + self.out = np.dot(self.x, self.y) + + def init_grad_input_output(self): + self.grad_out = np.ones((10, 2, 20), self.dtype) + 1J * np.ones( + (10, 2, 20), self.dtype) + self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T) + self.grad_y = np.sum(np.matmul( + np.conj(self.x).transpose(0, 2, 1), self.grad_out), + axis=0) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + if __name__ == "__main__": + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py index 330cf5a72b1..15ba331e9de 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py @@ -59,6 +59,7 @@ NEED_TO_FIX_OP_LIST = [ 'lstmp', 'margin_rank_loss', 'matmul', + 'matmul_v2', 'mul', 'multiplex', 'rank_loss', -- GitLab