From ca725c82f2198e237b6f7b894c49a4693d826472 Mon Sep 17 00:00:00 2001 From: Zhang Ting <709968123@qq.com> Date: Wed, 15 Jul 2020 10:30:47 +0800 Subject: [PATCH] improve fp16 performance of slice_grad, test=develop (#25523) --- paddle/fluid/operators/slice_op.cu | 135 +---------------------------- paddle/fluid/operators/slice_op.h | 4 +- 2 files changed, 3 insertions(+), 136 deletions(-) diff --git a/paddle/fluid/operators/slice_op.cu b/paddle/fluid/operators/slice_op.cu index d6945df9e1..7493b18936 100644 --- a/paddle/fluid/operators/slice_op.cu +++ b/paddle/fluid/operators/slice_op.cu @@ -12,145 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/slice_op.h" -#include "paddle/fluid/platform/cuda_device_function.h" -#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/float16.h" -namespace paddle { -namespace operators { - -using platform::PADDLE_CUDA_NUM_THREADS; - -template -__global__ void Padding(const paddle::platform::float16* d_out, - const int64_t* out_dims, const int64_t* in_dims, - const int64_t* offsets, int64_t n, - paddle::platform::float16* d_in) { - int64_t out_idx = threadIdx.x + blockDim.x * blockIdx.x; - if (out_idx < n) { - int64_t out_idx_tmp = out_idx; - int64_t coords[D] = {0}; - for (int i = D - 1; i >= 0; --i) { - coords[i] = out_idx_tmp % out_dims[i]; - out_idx_tmp /= out_dims[i]; - coords[i] += offsets[i]; - } - - int64_t in_idx = 0; - for (int i = 0; i < D; ++i) { - in_idx = in_idx * in_dims[i] + coords[i]; - } - - d_in[in_idx] = d_out[out_idx]; - } -} - -template <> -class SliceGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* d_out = ctx.Input(framework::GradVarName("Out")); - auto* d_in = ctx.Output(framework::GradVarName("Input")); - d_in->mutable_data(ctx.GetPlace()); - - auto out_dims = d_out->dims(); - auto in_dims = d_in->dims(); - int rank = out_dims.size(); - std::vector offsets(rank, 0); - auto axes = ctx.Attr>("axes"); - auto starts_int = ctx.Attr>("starts"); - std::vector starts(starts_int.begin(), starts_int.end()); - - auto list_new_starts_tensor = - ctx.MultiInput("StartsTensorList"); - - if (list_new_starts_tensor.size() > 0) { - starts = GetDataFromTensorList(list_new_starts_tensor); - } else if (ctx.HasInput("StartsTensor")) { - auto* starts_tensor = ctx.Input("StartsTensor"); - starts = GetDataFromTensor(starts_tensor); - } - - for (size_t i = 0; i < starts.size(); ++i) { - if (starts[i] < 0) { - starts[i] += in_dims[axes[i]]; - } - offsets[axes[i]] = std::max(starts[i], static_cast(0)); - } - - math::SetConstant - set_zero; - auto& dev_ctx = - ctx.template device_context(); - set_zero(dev_ctx, d_in, static_cast(0)); - - int64_t numel = d_out->numel(); - dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1); - dim3 threads(PADDLE_CUDA_NUM_THREADS); - auto stream = ctx.cuda_device_context().stream(); - const std::vector out_shape = - framework::vectorize(out_dims); - const std::vector in_shape = - framework::vectorize(in_dims); - - framework::Tensor out_dims_tensor; - framework::Tensor in_dims_tensor; - framework::Tensor offsets_tensor; - framework::TensorFromVector(out_shape, ctx.device_context(), - &out_dims_tensor); - framework::TensorFromVector(in_shape, ctx.device_context(), - &in_dims_tensor); - framework::TensorFromVector(offsets, ctx.device_context(), &offsets_tensor); - const int64_t* out_dims_ptr = out_dims_tensor.data(); - const int64_t* in_dims_ptr = in_dims_tensor.data(); - const int64_t* offsets_ptr = offsets_tensor.data(); - - switch (rank) { - case 1: - Padding<1><<>>( - d_out->data(), out_dims_ptr, in_dims_ptr, - offsets_ptr, numel, d_in->data()); - break; - case 2: - Padding<2><<>>( - d_out->data(), out_dims_ptr, in_dims_ptr, - offsets_ptr, numel, d_in->data()); - break; - case 3: - Padding<3><<>>( - d_out->data(), out_dims_ptr, in_dims_ptr, - offsets_ptr, numel, d_in->data()); - break; - case 4: - Padding<4><<>>( - d_out->data(), out_dims_ptr, in_dims_ptr, - offsets_ptr, numel, d_in->data()); - break; - case 5: - Padding<5><<>>( - d_out->data(), out_dims_ptr, in_dims_ptr, - offsets_ptr, numel, d_in->data()); - break; - case 6: - Padding<6><<>>( - d_out->data(), out_dims_ptr, in_dims_ptr, - offsets_ptr, numel, d_in->data()); - break; - } - } -}; - -} // namespace operators -} // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; + REGISTER_OP_CUDA_KERNEL( slice, ops::SliceKernel, ops::SliceKernel, diff --git a/paddle/fluid/operators/slice_op.h b/paddle/fluid/operators/slice_op.h index 39cc605f6b..ee46f4d821 100644 --- a/paddle/fluid/operators/slice_op.h +++ b/paddle/fluid/operators/slice_op.h @@ -350,7 +350,7 @@ class SliceGradKernel : public framework::OpKernel { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(context.GetPlace()); - T value = 0.0; + T value = T(0); math::SetConstant functor; for (int i = 0; i < d_in_size; ++i) { auto dim = input_array->at(i).dims(); @@ -440,7 +440,7 @@ class SliceGradKernel : public framework::OpKernel { auto d_out_t = framework::EigenTensor::From( *d_out, out_dims); - d_in_t.device(place) = d_out_t.pad(paddings, 0); + d_in_t.device(place) = d_out_t.pad(paddings, T(0)); } }; } // namespace operators -- GitLab