//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/cuda/manipulation.h"
#include "paddle/pten/kernels/cuda/utils.h"
#include "paddle/pten/kernels/functions/general/manipulation.h"

namespace pten {

// Flattens `x` into `out`: copies x's data into `out` and keeps the dims that
// `out` already carried on entry.
//
// `start_axis` and `stop_axis` are not read here; the flattened shape is taken
// from `out->dims()` as it is when this function is called.
// NOTE(review): presumably the caller / infer-meta step has already resized
// `out` to the flattened shape — confirm.
template <typename T>
void Flatten(const CUDAContext& dev_ctx,
             const DenseTensor& x,
             int start_axis,
             int stop_axis,
             DenseTensor* out) {
  // Remember the target (flattened) dims before copying. Saving and restoring
  // them implies pten::Copy may change out's dims to match `x`.
  auto out_dims = out->dims();
  // The `false` argument is forwarded to pten::Copy; presumably it selects a
  // non-blocking copy — confirm against pten::Copy's declaration.
  pten::Copy(dev_ctx, x, false, out);
  // Put the flattened dims back on top of the copied data.
  out->Resize(out_dims);
}

// TODO(yuanrisheng): this kernel is for training and `xshape` is an
// intermediate output tensor; is there a more flexible way to deal with this
// case?
//
// Same as Flatten, but also fills `xshape` from `x` via general::SetXShape
// (presumably recording x's original dims for the backward pass — confirm).
//
// SetXShape runs BEFORE flattening. If `out` aliases `x` (in-place flatten),
// flattening first would already have resized `x`, and `xshape` would then
// describe the output shape instead of the input shape.
template <typename T>
void FlattenWithXShape(const CUDAContext& dev_ctx,
                       const DenseTensor& x,
                       int start_axis,
                       int stop_axis,
                       DenseTensor* out,
                       DenseTensor* xshape) {
  general::SetXShape(x, xshape);
  Flatten<T>(dev_ctx, x, start_axis, stop_axis, out);
}

// Reshapes `x` to `shape` and writes the result into `out`.
//
// The output dims are computed by InferShapeFromVecValue from x's meta and
// `shape`. Behavior depends on whether `out` aliases `x`:
// - If `out` aliases `x` (in-place reshape), only the dims are updated; no
//   data is copied.
// - Otherwise, x's data is copied into `out`, and `out` is then resized to
//   the inferred dims.
void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
                          const DenseTensor& x,
                          const std::vector<int>& shape,
                          DenseTensor* out) {
  auto out_meta = InferShapeFromVecValue(x.meta(), shape);
  if (&x == out) {
    // In-place: the data is already where it needs to be; only the dims
    // change. Logged at verbose level (was LOG(INFO)) so every in-place
    // reshape does not spam the default log.
    VLOG(3) << "out_meta dims:" << out_meta.dims;
    out->Resize(out_meta.dims);
    return;
  }
  pten::Copy(dev_ctx, x, false, out);
  out->Resize(out_meta.dims);
}

// Same as ReshapeFromVectorVal, but also fills `xshape` from `x` via
// general::SetXShape (presumably recording x's original dims for the backward
// pass — confirm). Note the parameter order: `xshape` precedes `out`.
//
// SetXShape runs BEFORE the reshape. ReshapeFromVectorVal resizes `out` in
// place when `out` aliases `x`, so calling SetXShape afterwards would record
// the reshaped dims instead of the input dims.
void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
                                    const DenseTensor& x,
                                    const std::vector<int>& shape,
                                    DenseTensor* xshape,
                                    DenseTensor* out) {
  general::SetXShape(x, xshape);
  ReshapeFromVectorVal(dev_ctx, x, shape, out);
}

// Reshapes `x` using a target shape stored in the tensor `shape`: all
// shape.numel() int elements are read, in order, as the target dims. The
// actual reshape is delegated to ReshapeFromVectorVal.
// NOTE(review): the values are read directly through shape.data<int>() on the
// host, which assumes `shape`'s data is host-accessible (e.g. a CPU tensor) —
// confirm against how this kernel is registered/invoked.
void ReshapeFromDT(const CUDAContext& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& shape,
                   DenseTensor* out) {
  const int* first = shape.data<int>();
  const int* last = first + shape.numel();
  const std::vector<int> target_dims(first, last);
  ReshapeFromVectorVal(dev_ctx, x, target_dims, out);
}

// Same as ReshapeFromDT, but also fills `xshape` from `x` via
// general::SetXShape (presumably recording x's original dims for the backward
// pass — confirm). Note the parameter order: `xshape` precedes `out`.
//
// SetXShape runs BEFORE the reshape. The reshape resizes `out` in place when
// `out` aliases `x`, so calling SetXShape afterwards would record the reshaped
// dims instead of the input dims.
void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& shape,
                             DenseTensor* xshape,
                             DenseTensor* out) {
  general::SetXShape(x, xshape);
  ReshapeFromDT(dev_ctx, x, shape, out);
}

// Reshapes `x` using a target shape given as a list of tensors, one per target
// dim. Each tensor in `shape` must have dims [1]; otherwise PADDLE_ENFORCE_EQ
// raises an InvalidArgument error. The single int32 value of each tensor
// becomes one target dim, in order, and the reshape is delegated to
// ReshapeFromVectorVal.
// NOTE(review): values are read directly through data<int32_t>() on the host,
// which assumes each element's data is host-accessible — confirm.
void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
                         const DenseTensor& x,
                         const std::vector<DenseTensor>& shape,
                         DenseTensor* out) {
  std::vector<int> dims_from_tensors;
  dims_from_tensors.reserve(shape.size());
  for (const DenseTensor& tensor : shape) {
    // The argument expressions are kept as-is: PADDLE_ENFORCE_EQ embeds their
    // text in the error message it produces.
    PADDLE_ENFORCE_EQ(
        tensor.dims(),
        paddle::framework::make_ddim({1}),
        paddle::platform::errors::InvalidArgument(
            "If the element type of 'shape' in ReshapeOp is Tensor, "
            "the element's shape must be [1]. But received the element's shape "
            "is [%s]",
            tensor.dims()));
    const int32_t dim_value = *tensor.data<int32_t>();
    dims_from_tensors.push_back(dim_value);
  }
  ReshapeFromVectorVal(dev_ctx, x, dims_from_tensors, out);
}

// Same as ReshapeFromVectorDT, but also fills `xshape` from `x` via
// general::SetXShape (presumably recording x's original dims for the backward
// pass — confirm). Note the parameter order: `xshape` precedes `out`.
//
// SetXShape runs BEFORE the reshape. The reshape resizes `out` in place when
// `out` aliases `x`, so calling SetXShape afterwards would record the reshaped
// dims instead of the input dims.
void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx,
                                   const DenseTensor& x,
                                   const std::vector<DenseTensor>& shape,
                                   DenseTensor* xshape,
                                   DenseTensor* out) {
  general::SetXShape(x, xshape);
  ReshapeFromVectorDT(dev_ctx, x, shape, out);
}

}  // namespace pten

// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationCUDA);

using float16 = paddle::platform::float16;

// TODO(yuanrisheng): "flatten_contiguous_range" is kept for compatibility with
// the old kernel architecture; the kernel name should eventually be "flatten".
PT_REGISTER_KERNEL("flatten_contiguous_range",
                   CUDA,
                   ANY,
                   pten::Flatten,
                   float,
                   float16,
                   double,
                   uint8_t,
                   int8_t,
                   int,
                   int64_t) {}

// Variant that also produces `xshape`. It is registered for the same dtypes as
// "flatten_contiguous_range" (including float16), because FlattenWithXShape<T>
// only forwards to Flatten<T> plus the type-independent SetXShape.
PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
                   CUDA,
                   ANY,
                   pten::FlattenWithXShape,
                   float,
                   float16,
                   double,
                   uint8_t,
                   int8_t,
                   int,
                   int64_t) {}

PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
                                CUDA,
                                ANY,
                                pten::ReshapeFromVectorVal) {}

PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2.mid",
                                CUDA,
                                ANY,
                                pten::ReshapeFromVectorValWithXShape) {}