diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc
index a98d56d6fcff091a9e35e14e130dba50dd0d2174..6f7aeb63b1ced096954d64c2882dabfca808acd6 100644
--- a/paddle/fluid/operators/kron_op.cc
+++ b/paddle/fluid/operators/kron_op.cc
@@ -99,17 +99,15 @@ class KronGradOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad");
     OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                    framework::GradVarName("Out"), "kron_grad");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
-                   framework::GradVarName("X"), "kron_grad");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Y")), "Output",
-                   framework::GradVarName("Y"), "kron_grad");
 
     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
-    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
-    ctx->ShareLoD("X", /*->*/ x_grad_name);
-    ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
-    ctx->ShareLoD("Y", /*->*/ y_grad_name);
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
+    }
   }
 
  protected:
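The `InferShape` rewrite above is the core of the fix: `kron_grad` now tolerates a missing `X@GRAD` or `Y@GRAD` output instead of failing the unconditional `OP_INOUT_CHECK`. A minimal sketch of the situation this enables, assuming the `paddle.kron` API and fluid dygraph mode of the 2.0-alpha line (illustrative only, not part of this diff):

```python
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

a = np.random.randn(2, 2).astype("float32")
b = np.random.randn(3, 3).astype("float32")

with dg.guard(fluid.CPUPlace()):
    x = dg.to_variable(a)
    y = dg.to_variable(b)
    # Freeze X: the backward graph is pruned, so kron_grad is built
    # without an X@GRAD output, which previously tripped OP_INOUT_CHECK.
    x.stop_gradient = True
    y.stop_gradient = False
    loss = fluid.layers.reduce_sum(paddle.kron(x, y))
    loss.backward()
    print(y.gradient().shape)  # (3, 3); no gradient is produced for x
```

This mirrors the `no_grad_set` unit tests added at the end of this patch.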
diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h
index ec7a8a7d9bb4dbd43843b5ea43fe781c92851d54..62762f3f049b6d60f0e4853bcdd46e39369c1158 100644
--- a/paddle/fluid/operators/kron_op.h
+++ b/paddle/fluid/operators/kron_op.h
@@ -147,11 +147,14 @@ struct KronGradElemFunctor {
       index_b += stride_b_[i] * pos_bi;
     }
 
-    size_t index_out_a = index_a * numel_b_ + index_b;
-    size_t index_out_b = index_b * numel_a_ + index_a;
-
-    dout_a_[index_out_a] = dout_[idx] * B_[index_b];
-    dout_b_[index_out_b] = dout_[idx] * A_[index_a];
+    if (dout_a_) {
+      size_t index_out_a = index_a * numel_b_ + index_b;
+      dout_a_[index_out_a] = dout_[idx] * B_[index_b];
+    }
+    if (dout_b_) {
+      size_t index_out_b = index_b * numel_a_ + index_a;
+      dout_b_[index_out_b] = dout_[idx] * A_[index_a];
+    }
   }
 
  private:
@@ -222,35 +225,50 @@ struct KronGradOpFunctor {
     // dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
    // dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
     framework::Tensor dout_x;
-    dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
+    T* p_dout_x = nullptr;
+    if (dx) {
+      dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
+      p_dout_x = dout_x.data<T>();
+    }
     framework::Tensor dout_y;
-    dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
+    T* p_dout_y = nullptr;
+    if (dy) {
+      dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
+      p_dout_y = dout_y.data<T>();
+    }
 
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
     KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
-                                dout_x.data<T>(), dout_y.data<T>(),
-                                p_stride_dout, p_stride_x, p_stride_y,
-                                p_shape_y, numel_x, numel_y, ndims);
+                                p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
+                                p_stride_y, p_shape_y, numel_x, numel_y, ndims);
     for_range(func);
 
     // reduce_sum along axis 1
 #if __NVCC__
     auto stream = dev_ctx.stream();  // it is a cuda device_context
-    TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-        dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
-        stream);
-    TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-        dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
-        stream);
+    if (dx) {
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+          dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
+          stream);
+    }
+    if (dy) {
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+          dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
+          stream);
+    }
 #else
-    auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
-    auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
-    auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
-    auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
     auto* place = dev_ctx.eigen_device();
     Eigen::array<int, 1> reduce_dim = {1};
-    eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
-    eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
+    if (dx) {
+      auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
+      auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
+      eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
+    }
+    if (dy) {
+      auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
+      auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
+      eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
+    }
 #endif
   }
 };
@@ -307,17 +325,33 @@ class KronGradKernel : public framework::OpKernel<T> {
     auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-    dx->mutable_data<T>(ctx.GetPlace());
-    dy->mutable_data<T>(ctx.GetPlace());
+    if (dx) {
+      dx->mutable_data<T>(ctx.GetPlace());
+    }
+    if (dy) {
+      dy->mutable_data<T>(ctx.GetPlace());
+    }
 
     int ndims = dout->dims().size();
     framework::Tensor xx = UnsqueezeTo(*x, ndims);
-    framework::Tensor dxx = UnsqueezeTo(*dx, ndims);
     framework::Tensor yy = UnsqueezeTo(*y, ndims);
-    framework::Tensor dyy = UnsqueezeTo(*dy, ndims);
+
+    framework::Tensor* pdxx = nullptr;
+    framework::Tensor* pdyy = nullptr;
+    framework::Tensor dxx;
+    framework::Tensor dyy;
+    if (dx) {
+      dxx = UnsqueezeTo(*dx, ndims);
+      pdxx = &dxx;
+    }
+
+    if (dy) {
+      dyy = UnsqueezeTo(*dy, ndims);
+      pdyy = &dyy;
+    }
 
     KronGradOpFunctor<DeviceContext, T> func;
-    func(dev_ctx, *dout, xx, yy, &dxx, &dyy);
+    func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
   }
 };
diff --git a/python/paddle/complex/tensor/linalg.py b/python/paddle/complex/tensor/linalg.py
index 99c0ca5fbfab32c1581878f02db10d3338a507c8..fedfde5e5c2e00741b32fb4e2dc6ba58d5abb263 100644
--- a/python/paddle/complex/tensor/linalg.py
+++ b/python/paddle/complex/tensor/linalg.py
@@ -26,10 +26,10 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
 
     Args:
         x (ComplexVariable|Variable): The first input, can be a ComplexVariable
-            with data type complex32 or complex64, or a Variable with data type
+            with data type complex64 or complex128, or a Variable with data type
             float32 or float64.
         y (ComplexVariable|Variable): The second input, can be a ComplexVariable
-            with data type complex32 or complex64, or a Variable with data type
+            with data type complex64 or complex128, or a Variable with data type
             float32 or float64.
         transpose_x (bool): Whether to transpose :math:`x` before multiplication.
         transpose_y (bool): Whether to transpose :math:`y` before multiplication.
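The docstring correction above matters because `complex32` is not a dtype a ComplexVariable supports: its real and imaginary parts are float32 or float64 tensors, so the composite dtypes are complex64 and complex128. A hedged sketch of the corrected pairing, assuming `paddle.complex.matmul` and dygraph `to_variable` handle complex ndarrays the way the nearby kron docstring example does:

```python
import numpy as np
import paddle.complex as cpx
import paddle.fluid.dygraph as dg

a = np.array([[1.0 + 1.0j, 2.0 + 0.0j],
              [3.0 + 0.0j, 4.0 - 2.0j]], dtype=np.complex128)
b = np.eye(2, dtype=np.complex128)

with dg.guard():
    x = dg.to_variable(a)  # ComplexVariable with float64 real/imag parts
    y = dg.to_variable(b)
    out = cpx.matmul(x, y)
    print(out.numpy().dtype)  # complex128
```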
diff --git a/python/paddle/complex/tensor/math.py b/python/paddle/complex/tensor/math.py
index 302782281c79b336b54657976c634fe1ecf5c09d..aa4d6a0fc5500d0210b37e74a47c0d8730c0666c 100644
--- a/python/paddle/complex/tensor/math.py
+++ b/python/paddle/complex/tensor/math.py
@@ -367,6 +367,7 @@ def kron(x, y, name=None):
 
           import numpy as np
           import paddle
+          from paddle import fluid
           import paddle.fluid.dygraph as dg
 
           a = np.array([[1.0+1.0j, 2.0+1.0j], [3.0+1.0j, 4.0+1.0j]])
diff --git a/python/paddle/fluid/tests/unittests/test_kron_op.py b/python/paddle/fluid/tests/unittests/test_kron_op.py
index 2474c0738c1424869e64791cc82699f288f398b0..556d68955328f072b2fedc66d668c9a9875f57f7 100644
--- a/python/paddle/fluid/tests/unittests/test_kron_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kron_op.py
@@ -42,6 +42,12 @@ class TestKronOp(OpTest):
     def test_check_grad(self):
         self.check_grad(['X', 'Y'], 'Out')
 
+    def test_check_grad_ignore_x(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=set('X'))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
+
 
 class TestKronOp2(TestKronOp):
     def setUp(self):
diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
index eb1471e377a0b7e1bd2eb59a37ab3eba5c6b59e7..330cf5a72b1a56f674c82d4fb9e502785e2cf5c0 100644
--- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # check no_grad_set is None
-NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv']
+NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv', 'kron']
 
 # TODO(Shixiaowei02): Check if the items do not need fix.
 # no_grad_set has value in NEED_TO_FIX_OP_LIST
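For reference, the quantities the new `no_grad_set` tests exercise can be written directly in NumPy. This is an illustrative sketch, not code from the PR; the helper `kron_grad` is hypothetical and assumes 2-D inputs, while the operator itself handles arbitrary ranks via its stride arrays:

```python
import numpy as np

def kron_grad(dout, a, b):
    # For out = np.kron(a, b), out[i*p + k, j*q + l] = a[i, j] * b[k, l].
    # Reshaping dout to (m, p, n, q) exposes the (i, k, j, l) index grid.
    m, n = a.shape
    p, q = b.shape
    blocks = dout.reshape(m, p, n, q)
    da = np.einsum('ipjq,pq->ij', blocks, b)  # reduce over B's indices
    db = np.einsum('ipjq,ij->pq', blocks, a)  # reduce over A's indices
    return da, db

a = np.random.rand(2, 3)
b = np.random.rand(4, 5)
dout = np.ones((8, 15))  # gradient of a sum over the (8, 15) output
da, db = kron_grad(dout, a, b)
print(da.shape, db.shape)  # (2, 3) (4, 5)
```

Skipping one of the two `einsum` reductions is what the `if (dout_a_)` / `if (dout_b_)` branches achieve on the device side, while the functor changes also skip the corresponding scratch allocation and reduction.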