提交 c4c8f60b 编写于 作者: T tangwei12

sum_op selectedRows dim bug fix

上级 baff71d5
...@@ -105,8 +105,15 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -105,8 +105,15 @@ class SumKernel : public framework::OpKernel<T> {
auto &sel_row = get_selected_row(i); auto &sel_row = get_selected_row(i);
first_dim += sel_row.rows().size(); first_dim += sel_row.rows().size();
} }
auto in_dim =
framework::vectorize(get_selected_row(N - 1).value().dims()); std::vector<int64_t> in_dim;
for (int i = 0; i < N; i++) {
auto &sel_row = get_selected_row(i);
if (sel_row.rows().size() > 0) {
in_dim = framework::vectorize(sel_row.value().dims());
break;
}
}
in_dim[0] = static_cast<int64_t>(first_dim); in_dim[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim)); out_value->Resize(framework::make_ddim(in_dim));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册