From e7542a4dc22b922d2733d53cc1f474e12a5dce1f Mon Sep 17 00:00:00 2001 From: Jiawei Wang Date: Thu, 1 Apr 2021 20:43:33 +0800 Subject: [PATCH] fix stack op grad nullptr (#31962) (#32005) --- paddle/fluid/operators/stack_op.h | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index 38ab60afd91..03d53245289 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -30,7 +30,7 @@ struct StackGradFunctor { int i = idx / (n_ * post_); int which_x = idx / post_ - i * n_; int x_index = i * post_ + idx % post_; - dx_[which_x][x_index] = dy_[idx]; + if (dx_[which_x] != nullptr) dx_[which_x][x_index] = dy_[idx]; } private: @@ -95,19 +95,21 @@ class StackGradKernel : public framework::OpKernel { auto dx = ctx.MultiOutput(framework::GradVarName("X")); int axis = ctx.Attr("axis"); if (axis < 0) axis += dy->dims().size(); - int n = dy->dims()[axis]; std::vector dx_datas(n); // NOLINT + for (int i = 0; i < n; i++) { - dx_datas[i] = dx[i]->mutable_data(ctx.GetPlace()); + if (dx[i] == nullptr) { + dx_datas[i] = nullptr; + } else { + dx_datas[i] = dx[i]->mutable_data(ctx.GetPlace()); + } } auto dy_data = dy->data(); - int pre = 1; for (int i = 0; i < axis; ++i) pre *= dy->dims()[i]; int total_num = dy->numel(); int post = total_num / (n * pre); - auto &dev_ctx = ctx.template device_context(); auto dx_data_arr = dx_datas.data(); StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post); -- GitLab