slice_op.cu 6.3 KB
Newer Older
W
whs 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <thrust/device_vector.h>
#include "paddle/fluid/operators/math/math_function.h"
W
whs 已提交
17
#include "paddle/fluid/operators/slice_op.h"
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

// Scatters the gradient of a slice back into the gradient of its input for a
// rank-D tensor.
//
// Launch layout: 1-D grid of 1-D blocks providing at least `n` threads in
// total. Each thread handles one element of `d_out` (flat index < n).
//
// Parameters (all pointers are dereferenced on the device):
//   d_out    - gradient of the sliced output, `n` elements, laid out with the
//              last axis varying fastest.
//   out_dims - D extents of d_out.
//   in_dims  - D extents of d_in.
//   offsets  - D per-axis offsets of the slice window inside d_in.
//   n        - number of elements in d_out.
//   d_in     - gradient of the input. Only the elements inside the slice
//              window are written; everything else is left untouched, so the
//              caller must initialize it (SliceGradKernel zero-fills it).
template <size_t D>
__global__ void Padding(const paddle::platform::float16* d_out,
                        const int* out_dims, const int* in_dims,
                        const int* offsets, int64_t n,
                        paddle::platform::float16* d_in) {
  // Widen before multiplying: blockIdx.x and blockDim.x are 32-bit unsigned,
  // so their product would wrap for launches covering >= 2^32 elements.
  int64_t out_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (out_idx < n) {
    // Decompose the flat output index into per-axis coordinates (last axis
    // fastest) and shift each coordinate by the slice offset on that axis.
    int64_t out_idx_tmp = out_idx;
    int coords[D] = {0};
    for (int i = static_cast<int>(D) - 1; i >= 0; --i) {
      coords[i] = out_idx_tmp % out_dims[i];
      out_idx_tmp /= out_dims[i];
      coords[i] += offsets[i];
    }

    // Re-linearize the shifted coordinates against the input extents.
    int64_t in_idx = 0;
    for (int i = 0; i < static_cast<int>(D); ++i) {
      in_idx = in_idx * in_dims[i] + coords[i];
    }

    d_in[in_idx] = d_out[out_idx];
  }
}

// float16 specialization of the slice gradient on CUDA.
// Zero-fills d(Input), then copies d(Out) into the window of d(Input) that
// the forward slice read from, using the Padding<D> kernel (ranks 1..6).
template <>
class SliceGradKernel<paddle::platform::CUDADeviceContext,
                      paddle::platform::float16>
    : public framework::OpKernel<paddle::platform::float16> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
    d_in->mutable_data<paddle::platform::float16>(ctx.GetPlace());

    auto out_dims = d_out->dims();
    auto in_dims = d_in->dims();
    int rank = out_dims.size();
    std::vector<int> offsets(rank, 0);
    auto axes = ctx.Attr<std::vector<int>>("axes");
    auto starts = ctx.Attr<std::vector<int>>("starts");

    // Starts supplied as input tensors override the "starts" attribute;
    // "StartsTensorList" is checked before "StartsTensor".
    auto list_new_starts_tensor =
        ctx.MultiInput<framework::Tensor>("StartsTensorList");
    if (list_new_starts_tensor.size() > 0) {
      starts = get_new_data_from_tensorlist(list_new_starts_tensor);
    } else if (ctx.HasInput("StartsTensor")) {
      auto* starts_tensor = ctx.Input<framework::Tensor>("StartsTensor");
      starts = get_new_data_from_tensor(starts_tensor);
    }

    // A negative start counts from the end of its axis; the resulting offset
    // is clamped below at 0.
    // NOTE(review): assumes starts.size() <= axes.size() — confirm upstream.
    for (size_t i = 0; i < starts.size(); ++i) {
      if (starts[i] < 0) {
        starts[i] += in_dims[axes[i]];
      }
      offsets[axes[i]] = std::max(starts[i], 0);
    }

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::CUDADeviceContext>();

    // Elements of d_in outside the slice window receive zero gradient.
    math::SetConstant<paddle::platform::CUDADeviceContext,
                      paddle::platform::float16>
        set_zero;
    set_zero(dev_ctx, d_in, static_cast<paddle::platform::float16>(0));

    int64_t numel = d_out->numel();
    if (numel == 0) {
      // Empty slice: d_in is already all zeros and there is nothing to copy.
      return;
    }

    dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1);
    dim3 threads(PADDLE_CUDA_NUM_THREADS);
    auto stream = dev_ctx.stream();

    // Pack out_dims | in_dims | offsets into one host buffer so a single
    // device allocation and host-to-device copy serve all three kernel
    // arguments. The sub-array pointers are derived from each part's real
    // length, so the kernel sees exactly the values of each part.
    auto out_shape = framework::vectorize<int>(out_dims);
    auto in_shape = framework::vectorize<int>(in_dims);
    std::vector<int> packed;
    packed.reserve(out_shape.size() + in_shape.size() + offsets.size());
    packed.insert(packed.end(), out_shape.begin(), out_shape.end());
    packed.insert(packed.end(), in_shape.begin(), in_shape.end());
    packed.insert(packed.end(), offsets.begin(), offsets.end());
    // Must stay alive until the end of Compute: the kernel reads from it.
    thrust::device_vector<int> packed_dev(packed.begin(), packed.end());
    const int* out_dims_ptr = thrust::raw_pointer_cast(packed_dev.data());
    const int* in_dims_ptr = out_dims_ptr + out_shape.size();
    const int* offsets_ptr = in_dims_ptr + in_shape.size();

    const auto* d_out_data = d_out->data<paddle::platform::float16>();
    auto* d_in_data = d_in->data<paddle::platform::float16>();

    switch (rank) {
      case 1:
        Padding<1><<<blocks, threads, 0, stream>>>(d_out_data, out_dims_ptr,
                                                   in_dims_ptr, offsets_ptr,
                                                   numel, d_in_data);
        break;
      case 2:
        Padding<2><<<blocks, threads, 0, stream>>>(d_out_data, out_dims_ptr,
                                                   in_dims_ptr, offsets_ptr,
                                                   numel, d_in_data);
        break;
      case 3:
        Padding<3><<<blocks, threads, 0, stream>>>(d_out_data, out_dims_ptr,
                                                   in_dims_ptr, offsets_ptr,
                                                   numel, d_in_data);
        break;
      case 4:
        Padding<4><<<blocks, threads, 0, stream>>>(d_out_data, out_dims_ptr,
                                                   in_dims_ptr, offsets_ptr,
                                                   numel, d_in_data);
        break;
      case 5:
        Padding<5><<<blocks, threads, 0, stream>>>(d_out_data, out_dims_ptr,
                                                   in_dims_ptr, offsets_ptr,
                                                   numel, d_in_data);
        break;
      case 6:
        Padding<6><<<blocks, threads, 0, stream>>>(d_out_data, out_dims_ptr,
                                                   in_dims_ptr, offsets_ptr,
                                                   numel, d_in_data);
        break;
      default:
        // No Padding instantiation exists for this rank, so d_in stays all
        // zeros. NOTE(review): presumably the op's shape inference limits the
        // rank to 1..6 — confirm, otherwise this silently drops the gradient.
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle
W
whs 已提交
143 144

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward slice kernels on CUDA, one per supported element type.
REGISTER_OP_CUDA_KERNEL(
    slice,
    ops::SliceKernel<plat::CUDADeviceContext, float>,
    ops::SliceKernel<plat::CUDADeviceContext, double>,
    ops::SliceKernel<plat::CUDADeviceContext, int>,
    ops::SliceKernel<plat::CUDADeviceContext, int64_t>,
    ops::SliceKernel<plat::CUDADeviceContext, plat::float16>);

// Backward slice kernels on CUDA. The float16 entry uses the explicit
// SliceGradKernel<CUDADeviceContext, float16> specialization in this file.
REGISTER_OP_CUDA_KERNEL(
    slice_grad,
    ops::SliceGradKernel<plat::CUDADeviceContext, float>,
    ops::SliceGradKernel<plat::CUDADeviceContext, double>,
    ops::SliceGradKernel<plat::CUDADeviceContext, int>,
    ops::SliceGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::SliceGradKernel<plat::CUDADeviceContext, plat::float16>);