cross_kernel.cu 4.9 KB
Newer Older
0
0x45f 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/phi/kernels/cross_kernel.h"

0
0x45f 已提交
17
#include "paddle/phi/backends/gpu/gpu_context.h"
Z
zhangbopd 已提交
18 19
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
0
0x45f 已提交
20
#include "paddle/phi/core/kernel_registry.h"
Z
zhangbopd 已提交
21 22 23 24 25 26 27 28 29 30
#include "paddle/phi/kernels/funcs/reduce_function.h"

namespace phi {

// Cross-product kernel: for each length-3 vector along the cross axis,
// writes out = x x y (component-wise formula below).
//
// Parameters:
//   x, y              Input buffers, read at the same offsets that are written
//                     in `out`, so they are expected to share out's layout.
//   out               Output buffer.
//   stride            Distance, in elements, between consecutive components
//                     of one vector (the stride of the cross axis).
//   N                 Number of vectors to process (CrossKernel passes
//                     numel / 3).
//   index_calculator  Maps vector index `i` to the flat offset of that
//                     vector's first component; built by CrossKernel from the
//                     strides of the non-cross axes.
//
// Launch: a 1-D grid; iteration over [0, N) is driven by CUDA_KERNEL_LOOP, a
// Paddle macro defined elsewhere (presumably a grid-stride loop — TODO
// confirm). No shared memory is used.
template <typename T>
__global__ void Cross(const T* x,
                      const T* y,
                      T* out,
                      const int stride,
                      const int N,
                      phi::funcs::IndexCalculator index_calculator) {
  CUDA_KERNEL_LOOP(i, N) {
    const int offset = index_calculator(i);

    const int pos0 = offset;
    const int pos1 = offset + stride;
    const int pos2 = offset + 2 * stride;

    // Load every operand before the first store. `out` is not __restrict__,
    // so the compiler must assume it may alias x/y; reading up front avoids
    // re-loading operands after each store (6 loads instead of 12) and keeps
    // the result correct even if `out` shares storage with `x` or `y`.
    const T x0 = x[pos0];
    const T x1 = x[pos1];
    const T x2 = x[pos2];
    const T y0 = y[pos0];
    const T y1 = y[pos1];
    const T y2 = y[pos2];

    out[pos0] = x1 * y2 - x2 * y1;
    out[pos1] = x2 * y0 - x0 * y2;
    out[pos2] = x0 * y1 - x1 * y0;
  }
}

// GPU implementation of the `cross` operator: out = x x y along one axis of
// size 3.
//
// Parameters:
//   dev_ctx  Device context; supplies the stream and allocates `out`.
//   x, y     Input tensors. Only x's shape is inspected here; y is assumed to
//            have the same shape (not checked in this function — presumably
//            validated by the op's InferMeta; TODO confirm).
//   axis     Axis to take the cross product along. Negative values count from
//            the end. The sentinel DDim::kMaxRank means "use the first
//            dimension whose size is 3".
//   out      Output tensor; allocated here via dev_ctx.Alloc<T>.
//
// Raises (via PADDLE_ENFORCE):
//   OutOfRange       if an explicit axis is outside [-rank, rank - 1].
//   InvalidArgument  if the chosen axis does not have size 3, or if no axis
//                    of size 3 exists when axis == DDim::kMaxRank.
//
// The kernel is launched asynchronously on dev_ctx.stream().
template <typename T, typename Context>
void CrossKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 const DenseTensor& y,
                 int axis,
                 DenseTensor* out) {
  const auto x_dims = x.dims();
  const int rank = x_dims.size();
  int dim = axis;

  if (dim != DDim::kMaxRank) {
    PADDLE_ENFORCE_EQ(
        dim < rank && dim >= -rank,
        true,
        phi::errors::OutOfRange(
            "Attr(dim) is out of range, It's expected "
            "to be in range of [-%d, %d]. But received Attr(dim) = %d.",
            rank,
            rank - 1,
            dim));
    if (dim < 0) {
      dim += rank;
    }

    PADDLE_ENFORCE_EQ(
        x_dims[dim] == 3,
        true,
        phi::errors::InvalidArgument(
            "Input(X/Y).dims[dim] must be equal to 3. But received: "
            "Input(X/Y).dims[dim] = [%d].",
            x_dims[dim]));
  } else {
    // No explicit axis: pick the first dimension of size 3.
    for (int i = 0; i < rank; ++i) {
      if (x_dims[i] == 3) {
        dim = i;
        break;
      }
    }
    PADDLE_ENFORCE_EQ(dim == DDim::kMaxRank,
                      false,
                      phi::errors::InvalidArgument(
                          "There must be at least one dimension 'd' so that "
                          "Input(X/Y).dims()[d] is equal to 3. "
                          "But received: Input(X/Y).dims() == [%s].",
                          x_dims));
  }

  T* out_data = dev_ctx.template Alloc<T>(out);

  // Zero-size input: nothing to compute. Returning here avoids reading data
  // pointers of tensors that may hold no memory and avoids launching a
  // kernel with an empty (zero-block) grid.
  const int64_t numel = x.numel();
  if (numel == 0) {
    return;
  }

  const T* x_data = x.data<T>();
  const T* y_data = y.data<T>();

  // Collapse the shape to at most [outer, 3, inner]: `outer` is the product
  // of the dims before the cross axis, `inner` the product of the dims after
  // it. Either entry is omitted when there are no such dims.
  std::vector<int> merged_dims;
  if (dim > 0) {
    int outer = 1;
    for (int i = 0; i < dim; ++i) {
      outer *= x_dims[i];
    }
    merged_dims.push_back(outer);
  }
  const int merge_axis = static_cast<int>(merged_dims.size());
  merged_dims.push_back(x_dims[dim]);
  if (dim + 1 < rank) {
    int inner = 1;
    for (int i = dim + 1; i < rank; ++i) {
      inner *= x_dims[i];
    }
    merged_dims.push_back(inner);
  }
  const int merged_rank = static_cast<int>(merged_dims.size());

  // full_strides: row-major strides of merged_dims.
  // left_strides: row-major strides of merged_dims with the cross axis
  //               removed.
  // cal_dims:     indices of the non-cross axes, in ascending order.
  std::vector<int> full_strides(merged_rank);
  std::vector<int> left_strides;
  std::vector<int> cal_dims;
  int full_stride = 1;
  int left_stride = 1;
  for (int i = merged_rank - 1; i >= 0; --i) {
    full_strides[i] = full_stride;
    full_stride *= merged_dims[i];
    if (i == merge_axis) {
      continue;
    }
    left_strides.insert(left_strides.begin(), left_stride);
    left_stride *= merged_dims[i];
  }
  for (int i = 0; i < merged_rank; ++i) {
    if (i != merge_axis) {
      cal_dims.push_back(i);
    }
  }

  auto index_calculator = phi::funcs::IndexCalculator(
      merged_rank - 1, cal_dims, left_strides, full_strides);

  // One kernel iteration per length-3 vector.
  const int64_t num_vectors = numel / 3;
  backends::gpu::GpuLaunchConfig config =
      backends::gpu::GetGpuLaunchConfig1D(dev_ctx, num_vectors);

  Cross<<<config.block_per_grid,
          config.thread_per_block,
          0,
          dev_ctx.stream()>>>(x_data,
                              y_data,
                              out_data,
                              full_strides[merge_axis],
                              static_cast<int>(num_vectors),
                              index_calculator);
}
}  // namespace phi
0
0x45f 已提交
155 156 157

// Registers phi::CrossKernel as the GPU `cross` kernel for every layout and
// for the element types float, double, int and int64_t.
PD_REGISTER_KERNEL(cross,
                   GPU,
                   ALL_LAYOUT,
                   phi::CrossKernel,
                   float,
                   double,
                   int,
                   int64_t) {}