未验证 提交 66024e90 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14149 from chengduoZH/fix_sum_op_bug_relase

Fix sum op's GetExpectedKernelType
...@@ -358,11 +358,11 @@ static bool VarIsTensor(const Variable* var) { ...@@ -358,11 +358,11 @@ static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>(); return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
} }
static const Tensor* GetTensorFromVar(Variable* var) { const Tensor* GetTensorFromVar(const Variable* var) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>(); return static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
return var->GetMutable<SelectedRows>()->mutable_value(); return &(var->Get<SelectedRows>().value());
} else { } else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.", PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name()); var->Type().name());
...@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const { ...@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
template <> template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const { const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name); auto* var = InputVar(name);
return var == nullptr ? nullptr return var == nullptr ? nullptr : GetTensorFromVar(var);
: GetTensorFromVar(const_cast<Variable*>(var));
} }
template <> template <>
......
...@@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(const Variable* var);
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;
......
...@@ -82,14 +82,15 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -82,14 +82,15 @@ class SumOp : public framework::OperatorWithKernel {
if (x_vars[0]->IsType<framework::LoDTensor>()) { if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1; int dtype = -1;
for (auto& x_var : x_vars) { for (auto& x_var : x_vars) {
auto& lod_tensor = x_var->Get<framework::LoDTensor>(); // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
if (lod_tensor.numel() == 0) { auto tensor = framework::GetTensorFromVar(x_var);
if (tensor->numel() == 0) {
continue; continue;
} }
if (dtype == -1) { if (dtype == -1) {
dtype = framework::ToDataType(lod_tensor.type()); dtype = framework::ToDataType(tensor->type());
} else { } else {
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type())); PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
} }
} }
PADDLE_ENFORCE_NE(dtype, -1, PADDLE_ENFORCE_NE(dtype, -1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册