diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index 38ab60afd91a41ceb674f8261fa5bec72bbae5f0..03d5324528930c0a16efc3c00b9f6f527289641d 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -30,7 +30,7 @@ struct StackGradFunctor { int i = idx / (n_ * post_); int which_x = idx / post_ - i * n_; int x_index = i * post_ + idx % post_; - dx_[which_x][x_index] = dy_[idx]; + if (dx_[which_x] != nullptr) dx_[which_x][x_index] = dy_[idx]; } private: @@ -95,19 +95,21 @@ class StackGradKernel : public framework::OpKernel { auto dx = ctx.MultiOutput(framework::GradVarName("X")); int axis = ctx.Attr("axis"); if (axis < 0) axis += dy->dims().size(); - int n = dy->dims()[axis]; std::vector dx_datas(n); // NOLINT + for (int i = 0; i < n; i++) { - dx_datas[i] = dx[i]->mutable_data(ctx.GetPlace()); + if (dx[i] == nullptr) { + dx_datas[i] = nullptr; + } else { + dx_datas[i] = dx[i]->mutable_data(ctx.GetPlace()); + } } auto dy_data = dy->data(); - int pre = 1; for (int i = 0; i < axis; ++i) pre *= dy->dims()[i]; int total_num = dy->numel(); int post = total_num / (n * pre); - auto &dev_ctx = ctx.template device_context(); auto dx_data_arr = dx_datas.data(); StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);