提交 1c43ef49 编写于 作者: C chengduozh

fix GetTensorFromVar

test=release/1.1
上级 69c2acbd
......@@ -358,11 +358,11 @@ static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
}
const Tensor* GetTensorFromVar(Variable* var) {
const Tensor* GetTensorFromVar(const Variable* var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
return static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
} else if (var->IsType<SelectedRows>()) {
return var->GetMutable<SelectedRows>()->mutable_value();
return &(var->Get<SelectedRows>().value());
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
......@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);
return var == nullptr ? nullptr
: GetTensorFromVar(const_cast<Variable*>(var));
return var == nullptr ? nullptr : GetTensorFromVar(var);
}
template <>
......
......@@ -63,7 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
}
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(Variable* var);
const Tensor* GetTensorFromVar(const Variable* var);
class OperatorBase;
class ExecutionContext;
......
......@@ -83,8 +83,7 @@ class SumOp : public framework::OperatorWithKernel {
int dtype = -1;
for (auto& x_var : x_vars) {
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
auto tensor = framework::GetTensorFromVar(
const_cast<framework::Variable*>(x_var));
auto tensor = framework::GetTensorFromVar(x_var);
if (tensor->numel() == 0) {
continue;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册