提交 dd78b5df 编写于 作者: Q Qiao Longfei

sum op handle empty input

上级 cbe128bb
...@@ -269,12 +269,29 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -269,12 +269,29 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs, const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output) { framework::SelectedRows* output) {
PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input"); if (inputs.size() == 0) {
auto input_width = inputs[0]->value().dims()[1]; VLOG(3) << "no input! return";
auto input_height = inputs[0]->height(); return;
}
const framework::SelectedRows* has_value_input = nullptr;
for (auto* in : inputs) {
if (!in->rows().empty()) {
has_value_input = in;
break;
}
}
if (has_value_input == nullptr) {
VLOG(3) << "no input has value! just return" << std::endl;
return;
}
auto input_width = has_value_input->value().dims()[1];
auto input_height = has_value_input->height();
framework::SelectedRows& out = *output; framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set; std::set<int64_t> merged_row_set;
for (auto* input : inputs) { for (auto* input : inputs) {
if (input->rows().empty()) {
continue;
}
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
"all input should have same " "all input should have same "
"dimension except for the first one"); "dimension except for the first one");
...@@ -288,7 +305,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -288,7 +305,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
for (size_t i = 0; i < merge_rows.size(); ++i) { for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i; rows_to_id[merge_rows[i]] = i;
} }
out.set_rows(merge_rows); out.set_rows(merge_rows);
out.set_height(input_height); out.set_height(input_height);
out.mutable_value()->mutable_data<T>( out.mutable_value()->mutable_data<T>(
...@@ -303,6 +319,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -303,6 +319,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (auto* input : inputs) { for (auto* input : inputs) {
if (input->rows().empty()) {
continue;
}
auto* input_data = input->value().data<T>(); auto* input_data = input->value().data<T>();
auto& input_rows = input->rows(); auto& input_rows = input->rows();
......
...@@ -356,9 +356,7 @@ TEST(selected_rows_functor, cpu_merge_add_multi) { ...@@ -356,9 +356,7 @@ TEST(selected_rows_functor, cpu_merge_add_multi) {
for (size_t i = 0; i < ret_rows.size(); ++i) { for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) { for (size_t j = 0; j < row_numel; ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
std::cout << out_data[i * row_numel + j] << " ";
} }
std::cout << "\n";
} }
} }
......
...@@ -302,8 +302,6 @@ TEST(selected_rows_functor, gpu_merge_add) { ...@@ -302,8 +302,6 @@ TEST(selected_rows_functor, gpu_merge_add) {
for (size_t i = 0; i < ret_rows.size(); ++i) { for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) { for (size_t j = 0; j < row_numel; ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
std::cout << out_data[i * row_numel + j] << " ";
} }
std::cout << "\n";
} }
} }
...@@ -99,11 +99,17 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -99,11 +99,17 @@ class SumKernel : public framework::OpKernel<T> {
temp_in0.mutable_value()); temp_in0.mutable_value());
inputs.push_back(&temp_in0); inputs.push_back(&temp_in0);
for (size_t i = 1; i < in_vars.size(); ++i) { for (size_t i = 1; i < in_vars.size(); ++i) {
inputs.push_back(&in_vars[i]->Get<SelectedRows>()); auto &in = in_vars[i]->Get<SelectedRows>();
if (!in.rows().empty()) {
inputs.push_back(&in);
}
} }
} else { } else {
for (auto &in_var : in_vars) { for (auto &in_var : in_vars) {
inputs.push_back(&in_var->Get<SelectedRows>()); auto &in = in_var->Get<SelectedRows>();
if (!in.rows().empty()) {
inputs.push_back(&in_var->Get<SelectedRows>());
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册