提交 14f5a408 编写于 作者: Q Qiao Longfei

fix unit test

上级 02259575
...@@ -267,10 +267,15 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -267,10 +267,15 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input, const framework::SelectedRows& input,
framework::SelectedRows* output) { framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
framework::Vector<int64_t> input_rows(input.rows()); framework::Vector<int64_t> input_rows(input.rows());
if (input_rows.size() == 0) {
return;
}
framework::SelectedRows& out = *output;
std::set<int64_t> row_set(input_rows.begin(), input_rows.end()); std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end()); std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
framework::Vector<int64_t> merge_rows(merge_rows_cpu);
auto input_width = input.value().dims()[1]; auto input_width = input.value().dims()[1];
...@@ -313,8 +318,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -313,8 +318,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
"all input should have same height"); "all input should have same height");
merged_row_set.insert(input->rows().begin(), input->rows().end()); merged_row_set.insert(input->rows().begin(), input->rows().end());
} }
std::vector<int64_t> merge_rows(merged_row_set.begin(), std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
merged_row_set.end()); merged_row_set.end());
framework::Vector<int64_t> merge_rows(merge_rows_cpu);
out.set_rows(merge_rows); out.set_rows(merge_rows);
out.set_height(input_height); out.set_height(input_height);
...@@ -334,6 +340,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -334,6 +340,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
for (auto* input : inputs) { for (auto* input : inputs) {
auto* input_data = input->value().data<T>(); auto* input_data = input->value().data<T>();
auto& input_rows = input->rows(); auto& input_rows = input->rows();
if (input_rows.size() == 0) {
continue;
}
dim3 grid1(input_rows.size(), 1); dim3 grid1(input_rows.size(), 1);
MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>( MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册