提交 69c2acbd 编写于 作者: C chengduozh

fix ci fail

test=release/1.1
上级 dc9e23c4
...@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable* var) { ...@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>(); return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
} }
static const Tensor* GetTensorFromVar(Variable* var) { const Tensor* GetTensorFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>(); return var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
......
...@@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(Variable* var);
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;
......
...@@ -81,10 +81,10 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -81,10 +81,10 @@ class SumOp : public framework::OperatorWithKernel {
if (x_vars[0]->IsType<framework::LoDTensor>()) { if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1; int dtype = -1;
auto x_var_names = ctx.Inputs("X"); for (auto& x_var : x_vars) {
for (auto& x_var_n : x_var_names) {
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor. // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
auto tensor = ctx.Input<Tensor>(x_var_n); auto tensor = framework::GetTensorFromVar(
const_cast<framework::Variable*>(x_var));
if (tensor->numel() == 0) { if (tensor->numel() == 0) {
continue; continue;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册