未验证 提交 95f808c8 编写于 作者: J Jiawei Wang 提交者: GitHub

fix stack op grad nullptr (#31962)

上级 57d4288a
......@@ -30,7 +30,7 @@ struct StackGradFunctor {
int i = idx / (n_ * post_);
int which_x = idx / post_ - i * n_;
int x_index = i * post_ + idx % post_;
dx_[which_x][x_index] = dy_[idx];
if (dx_[which_x] != nullptr) dx_[which_x][x_index] = dy_[idx];
}
private:
......@@ -95,19 +95,21 @@ class StackGradKernel : public framework::OpKernel<T> {
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += dy->dims().size();
int n = dy->dims()[axis];
std::vector<T *> dx_datas(n); // NOLINT
for (int i = 0; i < n; i++) {
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
if (dx[i] == nullptr) {
dx_datas[i] = nullptr;
} else {
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
}
}
auto dy_data = dy->data<T>();
int pre = 1;
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
int total_num = dy->numel();
int post = total_num / (n * pre);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
auto dx_data_arr = dx_datas.data();
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册