提交 ab3e36da 编写于 作者: Q Qiao Longfei

update MergeAdd for selected_rows_functor.cu

上级 d5c64af2
...@@ -296,6 +296,52 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -296,6 +296,52 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
out.mutable_rows()->CUDAMutableData(context.GetPlace()), out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width); out.rows().size(), input_width);
} }
void operator()(const platform::CUDADeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output) {
PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input");
auto input_width = inputs[0]->value().dims()[1];
auto input_height = inputs[0]->height();
framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set;
for (auto* input : inputs) {
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
"all input should have same "
"dimension except for the first one");
PADDLE_ENFORCE_EQ(input_height, input->height(),
"all input should have same height");
merged_row_set.insert(input->rows().begin(), input->rows().end());
}
std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());
out.set_rows(merge_rows);
out.set_height(input_height);
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
for (auto* input : inputs) {
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
dim3 grid1(input_rows.size(), 1);
MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
input_data, input_rows.CUDAData(context.GetPlace()), out_data,
out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width);
}
}
}; };
template struct MergeAdd<platform::CUDADeviceContext, float>; template struct MergeAdd<platform::CUDADeviceContext, float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册