diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h
index 56a12852a91e87ddbe03d6ba96d4b10b1a451cec..f1692ae9563fd2551ae873472ae1a5b34132b6c4 100644
--- a/paddle/fluid/operators/stack_op.h
+++ b/paddle/fluid/operators/stack_op.h
@@ -72,6 +72,29 @@ class StackOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+// Functor that writes one element of the stacked output per call.
+//
+// The index arithmetic treats y as a flattened [pre, n, post] array and each
+// x[k] as a flattened [pre, post] array, so y[i][k][j] = x[k][i][j].
+template <typename VecXType, typename T>
+struct StackFunctor {
+  HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
+      : x_(x), y_(y), n_(n), post_(post) {}
+
+  HOSTDEVICE void operator()(int idx) {
+    int i = idx / (n_ * post_);             // position along `pre`
+    int which_x = idx / post_ - i * n_;     // input selected, in [0, n_)
+    int x_index = i * post_ + idx % post_;  // flat offset inside that input
+    y_[idx] = x_[which_x][x_index];
+  }
+
+ private:
+  VecXType x_;
+  T *y_;
+  int n_;
+  int post_;
+};
+
 template <typename VecDxType, typename T>
 struct StackGradFunctor {
   HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
@@ -91,6 +114,18 @@ struct StackGradFunctor {
   int post_;
 };
 
+// Launches StackFunctor through platform::ForRange constructed with
+// (ctx, total_num); total_num is expected to be the element count of y.
+// NOTE(review): assumes ForRange invokes the functor for every index in
+// [0, total_num) -- confirm against platform::ForRange.
+template <typename DeviceContext, typename VecXType, typename T>
+static inline void StackFunctorForRange(const DeviceContext &ctx,
+                                        const VecXType &x, T *y, int total_num,
+                                        int n, int post) {
+  platform::ForRange<DeviceContext> for_range(ctx, total_num);
+  for_range(StackFunctor<VecXType, T>(x, y, n, post));
+}
+
 template <typename DeviceContext, typename VecDxType, typename T>
 static inline void StackGradFunctorForRange(const DeviceContext &ctx,
                                             const VecDxType &dx, const T *dy,