未验证 提交 2f639113 编写于 作者: C chengduo 提交者: GitHub

Fix sum_op's GetExpectedKernelType (#14112)

* fix sum_op's GetExpectedKernelType
test=develop

* fix ci fail
test=develop
上级 5577f9b0
......@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
}
static const Tensor* GetTensorFromVar(Variable* var) {
const Tensor* GetTensorFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
......
......@@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
}
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(Variable* var);
class OperatorBase;
class ExecutionContext;
......
......@@ -82,14 +82,16 @@ class SumOp : public framework::OperatorWithKernel {
if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1;
for (auto& x_var : x_vars) {
auto& lod_tensor = x_var->Get<framework::LoDTensor>();
if (lod_tensor.numel() == 0) {
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
auto tensor = framework::GetTensorFromVar(
const_cast<framework::Variable*>(x_var));
if (tensor->numel() == 0) {
continue;
}
if (dtype == -1) {
dtype = framework::ToDataType(lod_tensor.type());
dtype = framework::ToDataType(tensor->type());
} else {
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type()));
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
}
}
PADDLE_ENFORCE_NE(dtype, -1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册