Unverified commit e7542a4d, authored by Jiawei Wang, committed by GitHub

fix stack op grad nullptr (#31962) (#32005)

Parent commit: b934d0b8
@@ -30,7 +30,7 @@ struct StackGradFunctor {
     int i = idx / (n_ * post_);
     int which_x = idx / post_ - i * n_;
     int x_index = i * post_ + idx % post_;
-    dx_[which_x][x_index] = dy_[idx];
+    if (dx_[which_x] != nullptr) dx_[which_x][x_index] = dy_[idx];
   }

  private:
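The functor scatters every element of the stacked gradient dy back into the per-input gradient buffers. The index math decomposes the flat index over a (pre, n, post) layout: for example, with pre = 2, n = 3, post = 4 and idx = 17, we get i = 17 / 12 = 1, which_x = 17 / 4 - 1 * 3 = 1, x_index = 1 * 4 + 17 % 4 = 5, so dy[17] lands in dx[1][5]. Before this fix the write went through dx_[which_x] unconditionally; when a slot in dx_ is a null pointer (presumably because that input does not require a gradient, e.g. stop_gradient is set), the write dereferenced nullptr. The added guard simply skips such slots.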
@@ -95,19 +95,21 @@ class StackGradKernel : public framework::OpKernel<T> {
     auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
     int axis = ctx.Attr<int>("axis");
     if (axis < 0) axis += dy->dims().size();
     int n = dy->dims()[axis];
     std::vector<T *> dx_datas(n);  // NOLINT
     for (int i = 0; i < n; i++) {
-      dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
+      if (dx[i] == nullptr) {
+        dx_datas[i] = nullptr;
+      } else {
+        dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
+      }
     }
     auto dy_data = dy->data<T>();
     int pre = 1;
     for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
     int total_num = dy->numel();
     int post = total_num / (n * pre);
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
     auto dx_data_arr = dx_datas.data();
     StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
...
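Below is a minimal, standalone sketch of the pattern the kernel now follows. It is plain C++, not Paddle code; the toy sizes (pre = 1, n = 3, post = 2) and the names dx0, dx2, dx_datas are made up for illustration. Gradient pointers are collected with null slots for inputs whose gradient is not needed, and the scatter loop checks each slot before writing, mirroring the two guards added by this commit.

// Standalone sketch of the nullptr-tolerant stack-grad scatter (toy example).
// Assumption: input 1 does not need a gradient, so its pointer slot stays null.
#include <cstdio>
#include <vector>

int main() {
  const int pre = 1, n = 3, post = 2;
  const int total_num = pre * n * post;
  std::vector<float> dy = {1, 2, 3, 4, 5, 6};  // stacked gradient

  std::vector<float> dx0(post), dx2(post);     // grad buffers that do exist
  std::vector<float *> dx_datas = {dx0.data(), nullptr, dx2.data()};

  for (int idx = 0; idx < total_num; ++idx) {
    int i = idx / (n * post);             // position along the "pre" axes
    int which_x = idx / post - i * n;     // which stacked input this element came from
    int x_index = i * post + idx % post;  // offset inside that input's gradient
    if (dx_datas[which_x] != nullptr)     // the guard introduced by this fix
      dx_datas[which_x][x_index] = dy[idx];
  }

  std::printf("dx0 = {%g, %g}, dx2 = {%g, %g}\n", dx0[0], dx0[1], dx2[0], dx2[1]);
  // Expected output: dx0 = {1, 2}, dx2 = {5, 6}; dy elements 3 and 4 are skipped.
  return 0;
}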