diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 14fcde2fe3b1c3acfc0994e9cd37a784c57826d7..9259bb740a8a2408e4dc7be21711560fdf250752 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable* var) { return var->IsType() || var->IsType(); } -static const Tensor* GetTensorFromVar(Variable* var) { +const Tensor* GetTensorFromVar(Variable* var) { if (var->IsType()) { return var->GetMutable(); } else if (var->IsType()) { diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 626b50edfd39424473be33e9f8baec5970471477..a04d2834eb94c2d8df9c6e48782d10bb3254a6dd 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) { } proto::VarType::Type GetDataTypeOfVar(const Variable* var); +const Tensor* GetTensorFromVar(Variable* var); class OperatorBase; class ExecutionContext; diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 34dbac2ab8dcc9bd2b91e2daa2f42806057f5f56..6fe30630e9683f59044b216b9e9b1f7dd647b1e2 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -82,14 +82,16 @@ class SumOp : public framework::OperatorWithKernel { if (x_vars[0]->IsType()) { int dtype = -1; for (auto& x_var : x_vars) { - auto& lod_tensor = x_var->Get(); - if (lod_tensor.numel() == 0) { + // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor. + auto tensor = framework::GetTensorFromVar( + const_cast(x_var)); + if (tensor->numel() == 0) { continue; } if (dtype == -1) { - dtype = framework::ToDataType(lod_tensor.type()); + dtype = framework::ToDataType(tensor->type()); } else { - PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type())); + PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type())); } } PADDLE_ENFORCE_NE(dtype, -1,