From 78cd3dd50730e0e73ab1395844ee07861d8f01e8 Mon Sep 17 00:00:00 2001
From: Feiyu Chan
Date: Thu, 30 Apr 2020 13:31:55 +0800
Subject: [PATCH] fix kron_op: when only one input needs gradient,
 test=develop (#24269)

fix kron_op: when only one input needs gradient
---
 paddle/fluid/operators/kron_op.cc          | 14 ++-
 paddle/fluid/operators/kron_op.h           | 88 +++++++++++++------
 .../fluid/tests/unittests/test_kron_op.py  |  6 ++
 .../white_list/no_grad_set_white_list.py   |  2 +-
 4 files changed, 74 insertions(+), 36 deletions(-)

diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc
index a98d56d6fc..6f7aeb63b1 100644
--- a/paddle/fluid/operators/kron_op.cc
+++ b/paddle/fluid/operators/kron_op.cc
@@ -99,17 +99,15 @@ class KronGradOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad");
     OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                    framework::GradVarName("Out"), "kron_grad");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
-                   framework::GradVarName("X"), "kron_grad");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Y")), "Output",
-                   framework::GradVarName("Y"), "kron_grad");
 
     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
-    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
-    ctx->ShareLoD("X", /*->*/ x_grad_name);
-    ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
-    ctx->ShareLoD("Y", /*->*/ y_grad_name);
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
+    }
   }
 
  protected:
diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h
index ec7a8a7d9b..62762f3f04 100644
--- a/paddle/fluid/operators/kron_op.h
+++ b/paddle/fluid/operators/kron_op.h
@@ -147,11 +147,14 @@ struct KronGradElemFunctor {
       index_b += stride_b_[i] * pos_bi;
     }
 
-    size_t index_out_a = index_a * numel_b_ + index_b;
-    size_t index_out_b = index_b * numel_a_ + index_a;
-
-    dout_a_[index_out_a] = dout_[idx] * B_[index_b];
-    dout_b_[index_out_b] = dout_[idx] * A_[index_a];
+    if (dout_a_) {
+      size_t index_out_a = index_a * numel_b_ + index_b;
+      dout_a_[index_out_a] = dout_[idx] * B_[index_b];
+    }
+    if (dout_b_) {
+      size_t index_out_b = index_b * numel_a_ + index_a;
+      dout_b_[index_out_b] = dout_[idx] * A_[index_a];
+    }
   }
 
  private:
@@ -222,35 +225,50 @@ struct KronGradOpFunctor {
     // dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
     // dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
     framework::Tensor dout_x;
-    dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
+    T* p_dout_x = nullptr;
+    if (dx) {
+      dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
+      p_dout_x = dout_x.data<T>();
+    }
     framework::Tensor dout_y;
-    dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
+    T* p_dout_y = nullptr;
+    if (dy) {
+      dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
+      p_dout_y = dout_y.data<T>();
+    }
 
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
     KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
-                                dout_x.data<T>(), dout_y.data<T>(),
-                                p_stride_dout, p_stride_x, p_stride_y,
-                                p_shape_y, numel_x, numel_y, ndims);
+                                p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
+                                p_stride_y, p_shape_y, numel_x, numel_y, ndims);
     for_range(func);
 
     // reduce_sum along axis 1
 #if __NVCC__
     auto stream = dev_ctx.stream();  // it is a CUDA device context
-    TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-        dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
-        stream);
-    TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-        dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
-        stream);
+    if (dx) {
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+          dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
+          stream);
+    }
+    if (dy) {
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+          dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
+          stream);
+    }
 #else
-    auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
-    auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
-    auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
-    auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
     auto* place = dev_ctx.eigen_device();
     Eigen::array<int, 1> reduce_dim = {1};
-    eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
-    eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
+    if (dx) {
+      auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
+      auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
+      eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
+    }
+    if (dy) {
+      auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
+      auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
+      eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
+    }
 #endif
   }
 };
@@ -307,17 +325,33 @@ class KronGradKernel : public framework::OpKernel<T> {
     auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-    dx->mutable_data<T>(ctx.GetPlace());
-    dy->mutable_data<T>(ctx.GetPlace());
+    if (dx) {
+      dx->mutable_data<T>(ctx.GetPlace());
+    }
+    if (dy) {
+      dy->mutable_data<T>(ctx.GetPlace());
+    }
 
     int ndims = dout->dims().size();
     framework::Tensor xx = UnsqueezeTo(*x, ndims);
-    framework::Tensor dxx = UnsqueezeTo(*dx, ndims);
     framework::Tensor yy = UnsqueezeTo(*y, ndims);
-    framework::Tensor dyy = UnsqueezeTo(*dy, ndims);
+
+    framework::Tensor* pdxx = nullptr;
+    framework::Tensor* pdyy = nullptr;
+    framework::Tensor dxx;
+    framework::Tensor dyy;
+    if (dx) {
+      dxx = UnsqueezeTo(*dx, ndims);
+      pdxx = &dxx;
+    }
+
+    if (dy) {
+      dyy = UnsqueezeTo(*dy, ndims);
+      pdyy = &dyy;
+    }
 
     KronGradOpFunctor<DeviceContext, T> func;
-    func(dev_ctx, *dout, xx, yy, &dxx, &dyy);
+    func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_kron_op.py b/python/paddle/fluid/tests/unittests/test_kron_op.py
index 57076b7551..1047f1bf1e 100644
--- a/python/paddle/fluid/tests/unittests/test_kron_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kron_op.py
@@ -42,6 +42,12 @@ class TestKronOp(OpTest):
     def test_check_grad(self):
         self.check_grad(['X', 'Y'], 'Out')
 
+    def test_check_grad_ignore_x(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=set('X'))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
+
 
 class TestKronOp2(TestKronOp):
     def setUp(self):
diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
index eb1471e377..330cf5a72b 100644
--- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # check no_grad_set is None
-NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv']
+NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv', 'kron']
 
 # TODO(Shixiaowei02): Check if the items do not need fix.
 # no_grad_set has value in NEED_TO_FIX_OP_LIST
-- 
GitLab
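
For reference, the failure this patch fixes is triggered whenever only one
operand of kron requires a gradient: KronGradOp's InferShape used to assert
unconditionally that both X@GRAD and Y@GRAD exist. The sketch below is a
minimal repro, not part of the commit; it assumes the 2.0-alpha-era
paddle.kron API and the fluid.dygraph idioms current at this commit's date:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(np.ones((2, 2), dtype="float32"))
        y = fluid.dygraph.to_variable(np.ones((3, 3), dtype="float32"))
        x.stop_gradient = True   # X@GRAD is never requested
        y.stop_gradient = False  # only Y@GRAD is requested from kron_grad
        out = paddle.kron(x, y)  # shape (6, 6)
        loss = fluid.layers.reduce_sum(out)
        loss.backward()          # used to fail the OP_INOUT_CHECK on X@GRAD
        print(y.gradient().shape)  # (3, 3)

With the patched kernel, the null dx pointer skips the dX branches in
KronGradElemFunctor and in the reduction, so only the requested gradient is
computed and no scratch buffer is allocated for the other input.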