// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstring>
#include <memory>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

// Scatters each element of the stacked gradient dy back into the gradient of
// the input it came from. The flat index `idx` into dy decomposes as
// (i, which_x, offset), with i in [0, pre), which_x in [0, n) and
// offset in [0, post).
template <typename VecDxType, typename T>
struct StackGradFunctor {
  HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n,
                              int post)
      : dx_(dx), dy_(dy), n_(n), post_(post) {}

  HOSTDEVICE void operator()(int idx) {
    int i = idx / (n_ * post_);
    int which_x = idx / post_ - i * n_;
    int x_index = i * post_ + idx % post_;
    // Inputs whose gradient is not requested are passed in as nullptr.
    if (dx_[which_x] != nullptr) dx_[which_x][x_index] = dy_[idx];
  }

 private:
  VecDxType dx_;
  const T *dy_;
  int n_;
  int post_;
};

template <typename DeviceContext, typename VecDxType, typename T>
static inline void StackGradFunctorForRange(const DeviceContext &ctx,
                                            const VecDxType &dx, const T *dy,
                                            int total_num, int n, int post) {
  platform::ForRange<DeviceContext> for_range(ctx, total_num);
  for_range(StackGradFunctor<VecDxType, T>(dx, dy, n, post));
}

template <typename DeviceContext, typename T>
class StackKernel : public framework::OpKernel<T> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto x = ctx.MultiInput<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Y");

    int axis = ctx.Attr<int>("axis");
    // A negative axis counts from the end; the output has one more dimension
    // than each input, hence the "+ 1".
    if (axis < 0) axis += (x[0]->dims().size() + 1);

    int n = static_cast<int>(x.size());
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    std::vector<const T *> x_datas(n);
    for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();

    // View each input as a (pre, post) matrix: `pre` is the product of the
    // dimensions before `axis`, `post` the product of the remaining ones.
    int pre = 1, post = 1;
    auto &dim = x[0]->dims();
    for (auto i = 0; i < axis; ++i) pre *= dim[i];
    for (auto i = axis; i < dim.size(); ++i) post *= dim[i];

    auto x_data_arr = x_datas.data();

    // For each of the `pre` rows, copy the matching row of `post` elements
    // from every input in turn, so the stacked axis is laid out contiguously
    // in the output.
    size_t x_offset = 0;
    size_t y_offset = 0;
    for (int i = 0; i < pre; i++) {
      for (int j = 0; j < n; j++) {
        std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset,
                    post * sizeof(T));
        y_offset += post;
      }
      x_offset += post;
    }
  }
};

template <typename DeviceContext, typename T>
class StackGradKernel : public framework::OpKernel<T> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis += dy->dims().size();

    int n = dy->dims()[axis];
    std::vector<T *> dx_datas(n);  // NOLINT
    for (int i = 0; i < n; i++) {
      if (dx[i] == nullptr) {
        dx_datas[i] = nullptr;
      } else {
        dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
      }
    }

    auto dy_data = dy->data<T>();
    int pre = 1;
    for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
    int total_num = dy->numel();
    int post = total_num / (n * pre);

    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    auto dx_data_arr = dx_datas.data();
    StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
                             post);
  }
};

}  // namespace operators
}  // namespace paddle