From 4ae23cc3c5e37de8685863190ce29f36f69caefd Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 15 Mar 2019 10:21:56 +0800 Subject: [PATCH] Impl fp16 compute kernel for slice_op (#16206) * Impl fp16 compute kernel for slice_op test=develop * Use data() to replace mutable_data() --- paddle/fluid/operators/slice_op.cu | 124 +++++++++++++++++- .../fluid/tests/unittests/test_slice_op.py | 24 ++++ 2 files changed, 146 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/slice_op.cu b/paddle/fluid/operators/slice_op.cu index 5efecb78d1..1af57b89a3 100644 --- a/paddle/fluid/operators/slice_op.cu +++ b/paddle/fluid/operators/slice_op.cu @@ -12,18 +12,138 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/slice_op.h" +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void Padding(const paddle::platform::float16* d_out, + const int* out_dims, const int* in_dims, + const int* offsets, int64_t n, + paddle::platform::float16* d_in) { + int64_t out_idx = threadIdx.x + blockDim.x * blockIdx.x; + if (out_idx < n) { + int coords[D] = {0}; + for (int i = D - 1; i >= 0; --i) { + coords[i] = out_idx % out_dims[i]; + out_idx /= out_dims[i]; + coords[i] += offsets[i]; + } + + int64_t in_idx = 0; + for (int i = 0; i < D - 1; ++i) { + in_idx += coords[i] * in_dims[i + 1]; + } + in_idx += coords[D - 1]; + + d_in[in_idx] = d_out[out_idx]; + } +} + +template <> +class SliceGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* d_in = ctx.Output(framework::GradVarName("Input")); + d_in->mutable_data(ctx.GetPlace()); + + auto out_dims = d_out->dims(); + auto in_dims = d_in->dims(); + int rank = out_dims.size(); + std::vector offsets(rank, 0); + auto axes = ctx.Attr>("axes"); + auto starts = ctx.Attr>("starts"); + + for (size_t i = 0; i < starts.size(); ++i) { + if (starts[i] < 0) { + starts[i] += in_dims[axes[i]]; + } + offsets[axes[i]] = std::max(starts[i], 0); + } + + math::SetConstant + set_zero; + auto& dev_ctx = + ctx.template device_context(); + set_zero(dev_ctx, d_in, static_cast(0)); + + int64_t numel = d_out->numel(); + dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1, 1, 1); + dim3 threads(PADDLE_CUDA_NUM_THREADS, 1, 1); + auto stream = ctx.cuda_device_context().stream(); + + auto out_shape = framework::vectorize2int(out_dims); + thrust::device_vector out_dims_vec(out_shape.begin(), out_shape.end()); + auto in_shape = framework::vectorize2int(in_dims); + thrust::device_vector in_dims_vec(in_shape.begin(), in_shape.end()); + thrust::device_vector offsets_vec(offsets.begin(), offsets.end()); + const int* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data()); + const int* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data()); + const int* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data()); + + switch (rank) { + case 1: + Padding<1><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 2: + Padding<2><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 3: + Padding<3><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 4: + Padding<4><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 5: + Padding<5><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + case 6: + Padding<6><<>>( + d_out->data(), out_dims_ptr, in_dims_ptr, + offsets_ptr, numel, d_in->data()); + break; + } + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( slice, ops::SliceKernel, ops::SliceKernel, ops::SliceKernel, - ops::SliceKernel); + ops::SliceKernel, + ops::SliceKernel); REGISTER_OP_CUDA_KERNEL( slice_grad, ops::SliceGradKernel, ops::SliceGradKernel, ops::SliceGradKernel, - ops::SliceGradKernel); + ops::SliceGradKernel, + ops::SliceGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 4e6ed3a74b..5fdabbabed 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest @@ -63,5 +64,28 @@ class TestCase2(TestSliceOp): self.out = self.input[-3:3, 0:100, :, 2:-1] +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFP16(TestSliceOp): + def config(self): + self.dtype = "float16" + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [-3, 0, 2] + self.ends = [3, 100, -1] + self.axes = [0, 1, 3] + self.out = self.input[-3:3, 0:100, :, 2:-1] + + def test_check_output(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-5) + + def test_check_grad_normal(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_grad_with_place( + place, ['Input'], 'Out', max_relative_error=0.006) + + if __name__ == '__main__': unittest.main() -- GitLab