提交 b4f52b01 编写于 作者: T tangwei12

bug fix when all inputs are empty

上级 3efac174
...@@ -187,6 +187,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -187,6 +187,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (in_dim.empty()) { if (in_dim.empty()) {
VLOG(3) << "WARNING: all the inputs are empty"; VLOG(3) << "WARNING: all the inputs are empty";
in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
} else {
in_dim[0] = static_cast<int64_t>(first_dim);
} }
in_dim[0] = static_cast<int64_t>(first_dim); in_dim[0] = static_cast<int64_t>(first_dim);
......
...@@ -116,9 +116,10 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -116,9 +116,10 @@ class SumKernel : public framework::OpKernel<T> {
} }
if (in_dim.empty()) { if (in_dim.empty()) {
VLOG(3) << "WARNING: all the inputs are empty"; VLOG(3) << "WARNING: all the inputs are empty";
} in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
} else {
in_dim[0] = static_cast<int64_t>(first_dim); in_dim[0] = static_cast<int64_t>(first_dim);
}
out_value->Resize(framework::make_ddim(in_dim)); out_value->Resize(framework::make_ddim(in_dim));
out_value->mutable_data<T>(context.GetPlace()); out_value->mutable_data<T>(context.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册