/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_vector.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

/*
 * Scatters the gradient of a slice back into the gradient of its input.
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per element of d_out.
 * Threads whose global index is >= n do nothing.
 *
 * d_out    : gradient w.r.t. the slice output, n elements, read linearly.
 * out_dims : device array of D ints, the shape of d_out.
 * in_dims  : device array of D ints, the shape of d_in.
 * offsets  : device array of D ints, start of the slice along each axis.
 * n        : number of elements in d_out.
 * d_in     : gradient w.r.t. the slice input. Only the positions covered by
 *            the slice are written; the caller is expected to have filled the
 *            rest (the host code below zero-fills it before launching).
 */
template <size_t D>
__global__ void Padding(const paddle::platform::float16* d_out,
                        const int* out_dims, const int* in_dims,
                        const int* offsets, int64_t n,
                        paddle::platform::float16* d_in) {
  // Widen before multiplying so the index cannot wrap in 32-bit arithmetic.
  const int64_t out_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (out_idx >= n) {
    return;
  }

  // Decompose the linear output index into per-axis coordinates (last axis
  // varies fastest) and shift each one by the slice start. A scratch copy is
  // consumed here so that out_idx itself is still valid for the read below.
  int64_t remaining = out_idx;
  int coords[D] = {0};
  for (int i = static_cast<int>(D) - 1; i >= 0; --i) {
    coords[i] = static_cast<int>(remaining % out_dims[i]) + offsets[i];
    remaining /= out_dims[i];
  }

  // Re-linearise the shifted coordinates against the input shape:
  //   ((c0 * d1 + c1) * d2 + c2) * ...
  // in_dims[0] is multiplied by the initial 0 and therefore has no effect.
  int64_t in_idx = 0;
  for (int i = 0; i < static_cast<int>(D); ++i) {
    in_idx = in_idx * in_dims[i] + coords[i];
  }

  d_in[in_idx] = d_out[out_idx];
}

/*
 * float16 CUDA specialization of the slice gradient kernel.
 *
 * Zero-fills the input gradient and then copies the output gradient into the
 * region selected by the "axes"/"starts" attributes using the Padding kernel.
 */
template <>
class SliceGradKernel<paddle::platform::CUDADeviceContext,
                      paddle::platform::float16>
    : public framework::OpKernel<paddle::platform::float16> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using CUDACtx = paddle::platform::CUDADeviceContext;
    using float16 = paddle::platform::float16;

    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
    d_in->mutable_data<float16>(ctx.GetPlace());

    auto out_dims = d_out->dims();
    auto in_dims = d_in->dims();
    int rank = out_dims.size();

    // Per-axis start offset of the slice inside the input; axes that are not
    // listed in "axes" keep offset 0.
    std::vector<int> offsets(rank, 0);
    auto axes = ctx.Attr<std::vector<int>>("axes");
    auto starts = ctx.Attr<std::vector<int>>("starts");
    for (size_t i = 0; i < starts.size(); ++i) {
      // A negative start counts from the end of that axis.
      if (starts[i] < 0) {
        starts[i] += in_dims[axes[i]];
      }
      offsets[axes[i]] = std::max(starts[i], 0);
    }

    // Everything outside the sliced region receives a zero gradient.
    math::SetConstant<CUDACtx, float16> set_zero;
    auto& dev_ctx = ctx.template device_context<CUDACtx>();
    set_zero(dev_ctx, d_in, static_cast<float16>(0));

    int64_t numel = d_out->numel();
    if (numel == 0) {
      // Empty slice: the zero-filled gradient is already the final result, so
      // skip the device allocations and the launch.
      return;
    }

    dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1, 1, 1);
    dim3 threads(PADDLE_CUDA_NUM_THREADS, 1, 1);
    auto stream = ctx.cuda_device_context().stream();

    // The kernel reads shapes and offsets through device pointers, so stage
    // them in device vectors that stay alive until the end of this function.
    auto out_shape = framework::vectorize2int(out_dims);
    thrust::device_vector<int> out_dims_vec(out_shape.begin(),
                                            out_shape.end());
    auto in_shape = framework::vectorize2int(in_dims);
    thrust::device_vector<int> in_dims_vec(in_shape.begin(), in_shape.end());
    thrust::device_vector<int> offsets_vec(offsets.begin(), offsets.end());
    const int* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data());
    const int* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data());
    const int* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data());

    const auto* d_out_data = d_out->data<float16>();
    auto* d_in_data = d_in->data<float16>();

    // The rank is a compile-time parameter of the kernel, so dispatch on it.
    switch (rank) {
      case 1:
        Padding<1><<<blocks, threads, 0, stream>>>(
            d_out_data, out_dims_ptr, in_dims_ptr, offsets_ptr, numel,
            d_in_data);
        break;
      case 2:
        Padding<2><<<blocks, threads, 0, stream>>>(
            d_out_data, out_dims_ptr, in_dims_ptr, offsets_ptr, numel,
            d_in_data);
        break;
      case 3:
        Padding<3><<<blocks, threads, 0, stream>>>(
            d_out_data, out_dims_ptr, in_dims_ptr, offsets_ptr, numel,
            d_in_data);
        break;
      case 4:
        Padding<4><<<blocks, threads, 0, stream>>>(
            d_out_data, out_dims_ptr, in_dims_ptr, offsets_ptr, numel,
            d_in_data);
        break;
      case 5:
        Padding<5><<<blocks, threads, 0, stream>>>(
            d_out_data, out_dims_ptr, in_dims_ptr, offsets_ptr, numel,
            d_in_data);
        break;
      case 6:
        Padding<6><<<blocks, threads, 0, stream>>>(
            d_out_data, out_dims_ptr, in_dims_ptr, offsets_ptr, numel,
            d_in_data);
        break;
      default:
        // NOTE(review): ranks outside 1..6 launch nothing, so the gradient
        // stays all-zero. Consider raising an error here if such ranks can
        // reach this kernel.
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// NOTE(review): confirm the registered element types below against the CPU
// registrations of slice / slice_grad.
REGISTER_OP_CUDA_KERNEL(
    slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>);

REGISTER_OP_CUDA_KERNEL(
    slice_grad,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);