提交 748ee35c 编写于 作者: Q Qiao Longfei

sum op: handle empty input; update selected_rows_functor.cu

上级 dd78b5df
...@@ -305,12 +305,29 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -305,12 +305,29 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs, const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output) { framework::SelectedRows* output) {
PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input"); if (inputs.size() == 0) {
auto input_width = inputs[0]->value().dims()[1]; VLOG(3) << "no input! return";
auto input_height = inputs[0]->height(); return;
}
const framework::SelectedRows* has_value_input = nullptr;
for (auto* in : inputs) {
if (!in->rows().empty()) {
has_value_input = in;
break;
}
}
if (has_value_input == nullptr) {
VLOG(3) << "no input has value! just return" << std::endl;
return;
}
auto input_width = has_value_input->value().dims()[1];
auto input_height = has_value_input->height();
framework::SelectedRows& out = *output; framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set; std::set<int64_t> merged_row_set;
for (auto* input : inputs) { for (auto* input : inputs) {
if (input->rows().empty()) {
continue;
}
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
"all input should have same " "all input should have same "
"dimension except for the first one"); "dimension except for the first one");
...@@ -338,11 +355,11 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -338,11 +355,11 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
dim3 threads(block_size, 1); dim3 threads(block_size, 1);
for (auto* input : inputs) { for (auto* input : inputs) {
auto* input_data = input->value().data<T>(); if (input->rows().empty()) {
auto& input_rows = input->rows();
if (input_rows.size() == 0) {
continue; continue;
} }
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
dim3 grid1(input_rows.size(), 1); dim3 grid1(input_rows.size(), 1);
MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>( MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册