提交 72aef6b1 编写于 作者: Q Qiao Longfei

sum selected rows check empty

上级 f13ae131
...@@ -116,8 +116,22 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -116,8 +116,22 @@ class SumKernel : public framework::OpKernel<T> {
auto *out = context.Output<SelectedRows>("Out"); auto *out = context.Output<SelectedRows>("Out");
out->mutable_rows()->clear(); out->mutable_rows()->clear();
bool has_data = false;
for (auto &in : inputs) {
if (in->rows().size() > 0) {
has_data = true;
break;
}
}
if (has_data) {
math::scatter::MergeAdd<DeviceContext, T> merge_add; math::scatter::MergeAdd<DeviceContext, T> merge_add;
merge_add(context.template device_context<DeviceContext>(), inputs, out); merge_add(context.template device_context<DeviceContext>(), inputs,
out);
} else {
// no data, just set a empty out tensor.
out->mutable_value()->mutable_data<T>(framework::make_ddim({0}),
context.GetPlace());
}
} else if (out_var->IsType<framework::LoDTensorArray>()) { } else if (out_var->IsType<framework::LoDTensorArray>()) {
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>(); auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) { for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册