From 25d44d40acfca5ed92dbc57fbaa2b01367a66f99 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 28 Dec 2018 14:17:33 +0800 Subject: [PATCH] sum op support empty selected rows as input --- paddle/fluid/operators/math/selected_rows_functor.cc | 4 ++++ paddle/fluid/operators/sum_op.cc | 8 +++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 1a11b584e..5f169dda2 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -195,6 +195,10 @@ struct SelectedRowsAddToTensor { void operator()(const platform::CPUDeviceContext& context, const framework::SelectedRows& input1, framework::Tensor* input2) { + if (input1.rows().size() == 0) { + LOG(WARNING) << "input selected rows is empty!"; + return; + } auto in1_height = input1.height(); auto in2_dims = input2->dims(); PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 4f717a435..83afe5819 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -41,7 +41,9 @@ class SumOp : public framework::OperatorWithKernel { return; // skip runtime infershape when is tensor array; } + auto x_var_types = ctx->GetInputsVarType("X"); auto x_dims = ctx->GetInputsDim("X"); + size_t N = x_dims.size(); PADDLE_ENFORCE_GT(N, 0, "Input tensors count should > 0."); if (N == 1) { @@ -49,7 +51,11 @@ class SumOp : public framework::OperatorWithKernel { } framework::DDim in_dim({0}); - for (auto& x_dim : x_dims) { + for (size_t i = 0; i < x_dims.size(); ++i) { + if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS) { + continue; + } + auto& x_dim = x_dims[i]; if (framework::product(x_dim) == 0) { continue; } -- GitLab