提交 920a9609 编写于 作者: Q Qiao Longfei

Optimize MergeAdd when the rows of all input SelectedRows contain no duplicates

上级 a06f4b2b
...@@ -296,6 +296,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -296,6 +296,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
auto input_height = has_value_input->height(); auto input_height = has_value_input->height();
framework::SelectedRows& out = *output; framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set; std::set<int64_t> merged_row_set;
size_t row_num = 0;
for (auto* input : inputs) { for (auto* input : inputs) {
if (input->rows().size() == 0) { if (input->rows().size() == 0) {
continue; continue;
...@@ -305,42 +306,71 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -305,42 +306,71 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
"dimension except for the first one"); "dimension except for the first one");
PADDLE_ENFORCE_EQ(input_height, input->height(), PADDLE_ENFORCE_EQ(input_height, input->height(),
"all input should have same height"); "all input should have same height");
row_num += input->rows().size();
merged_row_set.insert(input->rows().begin(), input->rows().end()); merged_row_set.insert(input->rows().begin(), input->rows().end());
} }
std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());
if (sorted_result) {
std::sort(merge_rows.begin(), merge_rows.end());
}
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
}
out.set_rows(merge_rows);
out.set_height(input_height); out.set_height(input_height);
out.mutable_value()->mutable_data<T>( out.mutable_value()->mutable_data<T>(
framework::make_ddim( framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}), {static_cast<int64_t>(merged_row_set.size()), input_width}),
context.GetPlace()); context.GetPlace());
auto* out_data = out.mutable_value()->data<T>();
math::SetConstant<platform::CPUDeviceContext, T> constant_functor; if (merged_row_set.size() == row_num && !sorted_result) {
constant_functor(context, out.mutable_value(), 0.0); // no duplicated ids, just concat the result together
std::vector<int64_t> merge_rows;
merge_rows.reserve(row_num);
// concat rows
for (auto* in : inputs) {
merge_rows.insert(merge_rows.end(), in->rows().begin(),
in->rows().end());
}
out.set_rows(merge_rows);
auto in_place = inputs[0]->place();
auto out_place = out.place();
int64_t copied_numel = 0;
for (auto* in : inputs) {
auto* in_data = in->value().data<T>();
auto in_numel = in->value().numel();
memory::Copy(boost::get<platform::CPUPlace>(out_place),
out_data + copied_numel,
boost::get<platform::CPUPlace>(in_place), in_data,
in_numel * sizeof(T));
copied_numel += in_numel;
}
} else {
std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());
auto* out_data = out.mutable_value()->data<T>(); if (sorted_result) {
std::sort(merge_rows.begin(), merge_rows.end());
}
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); out.set_rows(merge_rows);
for (auto* input : inputs) {
if (input->rows().size() == 0) { math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
continue; constant_functor(context, out.mutable_value(), 0.0);
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
} }
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows(); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (auto* input : inputs) {
for (size_t i = 0; i < input_rows.size(); i++) { if (input->rows().size() == 0) {
size_t out_i = rows_to_id[input_rows[i]]; continue;
elementwise_add_to<platform::CPUDeviceContext, T>( }
context, &blas, static_cast<size_t>(input_width), auto* input_data = input->value().data<T>();
&input_data[i * input_width], &out_data[out_i * input_width]); auto& input_rows = input->rows();
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
elementwise_add_to<platform::CPUDeviceContext, T>(
context, &blas, static_cast<size_t>(input_width),
&input_data[i * input_width], &out_data[out_i * input_width]);
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册