提交 dc9e23c4 编写于 作者: C chengduozh

fix sum_op's GetExpectedKernelType

test=develop
上级 77193498
...@@ -81,15 +81,17 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -81,15 +81,17 @@ class SumOp : public framework::OperatorWithKernel {
if (x_vars[0]->IsType<framework::LoDTensor>()) { if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1; int dtype = -1;
for (auto& x_var : x_vars) { auto x_var_names = ctx.Inputs("X");
auto& lod_tensor = x_var->Get<framework::LoDTensor>(); for (auto& x_var_n : x_var_names) {
if (lod_tensor.numel() == 0) { // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
auto tensor = ctx.Input<Tensor>(x_var_n);
if (tensor->numel() == 0) {
continue; continue;
} }
if (dtype == -1) { if (dtype == -1) {
dtype = framework::ToDataType(lod_tensor.type()); dtype = framework::ToDataType(tensor->type());
} else { } else {
PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type())); PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
} }
} }
PADDLE_ENFORCE_NE(dtype, -1, PADDLE_ENFORCE_NE(dtype, -1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册